mistralrs_core/utils/progress.rs

1use indicatif::{MultiProgress, ProgressBar, ProgressBarIter, ProgressIterator, ProgressStyle};
2use mistralrs_quant::get_immediate_isq;
3use rayon::iter::{IntoParallelIterator, ParallelIterator};
4use rayon::prelude::*;
5use std::iter::Iterator;
6use tqdm::Iter;
7
8// Optionally display a progress bar via the `tqdm` crate:
9// Usage: `iter.with_progress(true)`
10// Similar to the `iter.tqdm()` feature except this supports opt-in via parameter.
11pub trait IterWithProgress<'a, T>: Iterator<Item = T> + 'a {
12    fn with_progress(self, is_silent: bool) -> Box<dyn Iterator<Item = T> + 'a>
13    where
14        Self: Sized,
15    {
16        // TODO: Should `is_silent` instead be referenced as a global read-only state? (`AtomicBool`)
17        if is_silent {
18            Box::new(self)
19        } else {
20            Box::new(self.tqdm())
21        }
22    }
23}
24
25impl<'a, T: Iterator + 'a> IterWithProgress<'a, T::Item> for T {}
26
27/// Nice progress bar with over an iterator and a message.
28/// COLOR is one of r,g,b
29pub struct NiceProgressBar<'a, T: ExactSizeIterator, const COLOR: char = 'b'>(
30    pub T,
31    pub &'static str,
32    pub &'a MultiProgress,
33);
34
35impl<T: ExactSizeIterator, const COLOR: char> IntoIterator for NiceProgressBar<'_, T, COLOR> {
36    type IntoIter = ProgressBarIter<T>;
37    type Item = T::Item;
38
39    fn into_iter(self) -> Self::IntoIter {
40        let color = match COLOR {
41            'b' => "blue",
42            'g' => "green",
43            'r' => "red",
44            other => panic!("Color char `{other}` not supported"),
45        };
46        let bar = ProgressBar::new(self.0.len() as u64);
47        bar.set_style(
48            ProgressStyle::default_bar()
49                .template(&format!(
50                    "{}: [{{elapsed_precise}}] [{{bar:40.{color}/{color}}}] {{pos}}/{{len}} ({{eta}})",
51                    self.1
52                ))
53                .unwrap()
54                .progress_chars("#>-"),
55        );
56
57        // Add to the multi progress
58        self.2.add(bar.clone());
59
60        self.0.progress_with(bar)
61    }
62}
63
64/// Parallel iterator with progress reporting.
65pub struct ParProgress<I> {
66    iter: I,
67    bar: ProgressBar,
68}
69
70impl<I> ParallelIterator for ParProgress<I>
71where
72    I: ParallelIterator,
73    I::Item: Send,
74{
75    type Item = I::Item;
76
77    fn drive_unindexed<C>(self, consumer: C) -> C::Result
78    where
79        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
80    {
81        let bar = self.bar.clone();
82        let iter = self.iter.map(move |item| {
83            bar.inc(1);
84            item
85        });
86        iter.drive_unindexed(consumer)
87    }
88}
89
90impl<I> IndexedParallelIterator for ParProgress<I>
91where
92    I: IndexedParallelIterator,
93    I::Item: Send,
94{
95    fn len(&self) -> usize {
96        self.iter.len()
97    }
98
99    fn drive<C>(self, consumer: C) -> C::Result
100    where
101        C: rayon::iter::plumbing::Consumer<Self::Item>,
102    {
103        let bar = self.bar.clone();
104        let iter = self.iter.map(move |item| {
105            bar.inc(1);
106            item
107        });
108        iter.drive(consumer)
109    }
110
111    fn with_producer<CB>(self, callback: CB) -> CB::Output
112    where
113        CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
114    {
115        let bar = self.bar.clone();
116        let iter = self.iter.map(move |item| {
117            bar.inc(1);
118            item
119        });
120        iter.with_producer(callback)
121    }
122}
123
124impl<'a, T, const COLOR: char> IntoParallelIterator for NiceProgressBar<'a, T, COLOR>
125where
126    T: ExactSizeIterator + IntoParallelIterator + Send + Sync + 'a,
127    <T as IntoParallelIterator>::Item: Send + 'a,
128    T::Iter: ParallelIterator<Item = <T as IntoParallelIterator>::Item>
129        + IndexedParallelIterator<Item = <T as IntoParallelIterator>::Item>
130        + Send,
131{
132    type Iter = ParProgress<T::Iter>;
133    type Item = <T as IntoParallelIterator>::Item;
134
135    fn into_par_iter(self) -> Self::Iter {
136        let color = match COLOR {
137            'b' => "blue",
138            'g' => "green",
139            'r' => "red",
140            other => panic!("Color char `{other}` not supported"),
141        };
142        let bar = ProgressBar::new(self.0.len() as u64);
143        bar.set_style(
144            ProgressStyle::default_bar()
145                .template(&format!(
146                    "{}: [{{elapsed_precise}}] [{{bar:40.{color}/{color}}}] {{pos}}/{{len}} ({{eta}})",
147                    self.1
148                ))
149                .unwrap()
150                .progress_chars("#>-"),
151        );
152        self.2.add(bar.clone());
153        ParProgress {
154            iter: self.0.into_par_iter(),
155            bar,
156        }
157    }
158}
159
160impl<'a, T, const COLOR: char> NiceProgressBar<'a, T, COLOR>
161where
162    T: ExactSizeIterator + IntoParallelIterator + Send + Sync + 'a,
163    <T as IntoParallelIterator>::Item: Send + 'a,
164    T::Iter: ParallelIterator<Item = <T as IntoParallelIterator>::Item>
165        + IndexedParallelIterator<Item = <T as IntoParallelIterator>::Item>
166        + Send,
167    T: IntoParallelIterator<Item = <T as Iterator>::Item>,
168{
169    /// Applies the given closure over the items, optionally in parallel, and collects the results.
170    ///
171    /// - `is_parallel`: If true, uses Rayon parallel iteration; otherwise uses sequential iteration.
172    /// - `f`: A closure to apply to each item.
173    pub fn run<F, U>(self, _is_parallel: bool, f: F) -> candle_core::Result<Vec<U>>
174    where
175        F: Fn(<T as IntoParallelIterator>::Item) -> candle_core::Result<U> + Sync + Send,
176        U: Send,
177    {
178        // if is_parallel {
179        //     self.into_par_iter().map(f).collect()
180        // } else {
181        //     self.into_iter().map(f).collect()
182        // }
183        self.into_iter().map(f).collect()
184    }
185
186    /// Applies the given closure over the items, optionally in parallel, and collects the results.
187    ///
188    /// - `f`: A closure to apply to each item.
189    pub fn par_iter_if_isq<F, U>(self, f: F) -> candle_core::Result<Vec<U>>
190    where
191        F: Fn(<T as IntoParallelIterator>::Item) -> candle_core::Result<U> + Sync + Send,
192        U: Send,
193    {
194        self.run(get_immediate_isq().is_some_and(|x| x.ty.is_some()), f)
195    }
196}