mistralrs_core/utils/progress.rs

1use indicatif::{
2    MultiProgress, ProgressBar, ProgressBarIter, ProgressDrawTarget, ProgressIterator,
3    ProgressStyle,
4};
5use mistralrs_quant::get_immediate_isq;
6use rayon::iter::{IntoParallelIterator, ParallelIterator};
7use rayon::prelude::*;
8use std::iter::Iterator;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use tqdm::Iter;
11
/// Number of live [`ProgressScopeGuard`]s that were created with `silent == true`.
///
/// Progress output counts as suppressed whenever this is non-zero, so nested or
/// overlapping silent scopes compose: suppression ends only once the last silent
/// guard has been dropped.
static PROGRESS_SUPPRESS_COUNT: AtomicUsize = AtomicUsize::new(0);

/// RAII guard that suppresses progress bar drawing while it is alive.
///
/// A guard created with `silent == true` increments a process-wide counter, and
/// dropping it decrements the counter again; a guard created with `silent == false`
/// never touches the counter.
///
/// Bind the guard to a named variable (`let _guard = ...`). Writing
/// `let _ = ProgressScopeGuard::new(true);` drops the guard at the end of that
/// statement, which ends the suppression immediately.
#[derive(Debug)]
#[must_use = "progress suppression ends as soon as the guard is dropped; bind it to a named variable"]
pub struct ProgressScopeGuard {
    /// Whether this guard incremented the counter and therefore must decrement it on drop.
    suppressed: bool,
}

impl ProgressScopeGuard {
    /// Creates a guard. When `silent` is true, the suppression counter stays raised
    /// until the returned guard is dropped.
    pub fn new(silent: bool) -> Self {
        if silent {
            PROGRESS_SUPPRESS_COUNT.fetch_add(1, Ordering::SeqCst);
        }
        Self { suppressed: silent }
    }
}

impl Drop for ProgressScopeGuard {
    fn drop(&mut self) {
        // Only undo what `new` did, so non-silent guards leave the counter alone.
        if self.suppressed {
            PROGRESS_SUPPRESS_COUNT.fetch_sub(1, Ordering::SeqCst);
        }
    }
}
35
36#[inline]
37pub fn progress_suppressed() -> bool {
38    PROGRESS_SUPPRESS_COUNT.load(Ordering::SeqCst) > 0
39}
40
41#[inline]
42pub fn configure_progress_bar(bar: &ProgressBar) {
43    if progress_suppressed() {
44        bar.set_draw_target(ProgressDrawTarget::hidden());
45    }
46}
47
48pub fn new_multi_progress() -> MultiProgress {
49    let multi = MultiProgress::new();
50    if progress_suppressed() {
51        multi.set_draw_target(ProgressDrawTarget::hidden());
52    }
53    multi
54}
55
56// Optionally display a progress bar via the `tqdm` crate:
57// Usage: `iter.with_progress(true)`
58// Similar to the `iter.tqdm()` feature except this supports opt-in via parameter.
59pub trait IterWithProgress<'a, T>: Iterator<Item = T> + 'a {
60    fn with_progress(self, is_silent: bool) -> Box<dyn Iterator<Item = T> + 'a>
61    where
62        Self: Sized,
63    {
64        // TODO: Should `is_silent` instead be referenced as a global read-only state? (`AtomicBool`)
65        if is_silent {
66            Box::new(self)
67        } else {
68            Box::new(self.tqdm())
69        }
70    }
71}
72
73impl<'a, T: Iterator + 'a> IterWithProgress<'a, T::Item> for T {}
74
75/// Nice progress bar with over an iterator and a message.
76/// COLOR is one of r,g,b
77pub struct NiceProgressBar<'a, T: ExactSizeIterator, const COLOR: char = 'b'>(
78    pub T,
79    pub &'static str,
80    pub &'a MultiProgress,
81);
82
83impl<T: ExactSizeIterator, const COLOR: char> IntoIterator for NiceProgressBar<'_, T, COLOR> {
84    type IntoIter = ProgressBarIter<T>;
85    type Item = T::Item;
86
87    fn into_iter(self) -> Self::IntoIter {
88        let color = match COLOR {
89            'b' => "blue",
90            'g' => "green",
91            'r' => "red",
92            other => panic!("Color char `{other}` not supported"),
93        };
94        let bar = ProgressBar::new(self.0.len() as u64);
95        configure_progress_bar(&bar);
96        bar.set_style(
97            ProgressStyle::default_bar()
98                .template(&format!(
99                    "{}: [{{elapsed_precise}}] [{{bar:40.{color}/{color}}}] {{pos}}/{{len}} ({{eta}})",
100                    self.1
101                ))
102                .unwrap()
103                .progress_chars("#>-"),
104        );
105
106        // Add to the multi progress
107        self.2.add(bar.clone());
108
109        self.0.progress_with(bar)
110    }
111}
112
113/// Parallel iterator with progress reporting.
114pub struct ParProgress<I> {
115    iter: I,
116    bar: ProgressBar,
117}
118
119impl<I> ParallelIterator for ParProgress<I>
120where
121    I: ParallelIterator,
122    I::Item: Send,
123{
124    type Item = I::Item;
125
126    fn drive_unindexed<C>(self, consumer: C) -> C::Result
127    where
128        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
129    {
130        let bar = self.bar.clone();
131        let iter = self.iter.map(move |item| {
132            bar.inc(1);
133            item
134        });
135        iter.drive_unindexed(consumer)
136    }
137}
138
139impl<I> IndexedParallelIterator for ParProgress<I>
140where
141    I: IndexedParallelIterator,
142    I::Item: Send,
143{
144    fn len(&self) -> usize {
145        self.iter.len()
146    }
147
148    fn drive<C>(self, consumer: C) -> C::Result
149    where
150        C: rayon::iter::plumbing::Consumer<Self::Item>,
151    {
152        let bar = self.bar.clone();
153        let iter = self.iter.map(move |item| {
154            bar.inc(1);
155            item
156        });
157        iter.drive(consumer)
158    }
159
160    fn with_producer<CB>(self, callback: CB) -> CB::Output
161    where
162        CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
163    {
164        let bar = self.bar.clone();
165        let iter = self.iter.map(move |item| {
166            bar.inc(1);
167            item
168        });
169        iter.with_producer(callback)
170    }
171}
172
173impl<'a, T, const COLOR: char> IntoParallelIterator for NiceProgressBar<'a, T, COLOR>
174where
175    T: ExactSizeIterator + IntoParallelIterator + Send + Sync + 'a,
176    <T as IntoParallelIterator>::Item: Send + 'a,
177    T::Iter: ParallelIterator<Item = <T as IntoParallelIterator>::Item>
178        + IndexedParallelIterator<Item = <T as IntoParallelIterator>::Item>
179        + Send,
180{
181    type Iter = ParProgress<T::Iter>;
182    type Item = <T as IntoParallelIterator>::Item;
183
184    fn into_par_iter(self) -> Self::Iter {
185        let color = match COLOR {
186            'b' => "blue",
187            'g' => "green",
188            'r' => "red",
189            other => panic!("Color char `{other}` not supported"),
190        };
191        let bar = ProgressBar::new(self.0.len() as u64);
192        configure_progress_bar(&bar);
193        bar.set_style(
194            ProgressStyle::default_bar()
195                .template(&format!(
196                    "{}: [{{elapsed_precise}}] [{{bar:40.{color}/{color}}}] {{pos}}/{{len}} ({{eta}})",
197                    self.1
198                ))
199                .unwrap()
200                .progress_chars("#>-"),
201        );
202        self.2.add(bar.clone());
203        ParProgress {
204            iter: self.0.into_par_iter(),
205            bar,
206        }
207    }
208}
209
210impl<'a, T, const COLOR: char> NiceProgressBar<'a, T, COLOR>
211where
212    T: ExactSizeIterator + IntoParallelIterator + Send + Sync + 'a,
213    <T as IntoParallelIterator>::Item: Send + 'a,
214    T::Iter: ParallelIterator<Item = <T as IntoParallelIterator>::Item>
215        + IndexedParallelIterator<Item = <T as IntoParallelIterator>::Item>
216        + Send,
217    T: IntoParallelIterator<Item = <T as Iterator>::Item>,
218{
219    /// Applies the given closure over the items, optionally in parallel, and collects the results.
220    ///
221    /// - `is_parallel`: If true, uses Rayon parallel iteration; otherwise uses sequential iteration.
222    /// - `f`: A closure to apply to each item.
223    pub fn run<F, U>(self, _is_parallel: bool, f: F) -> candle_core::Result<Vec<U>>
224    where
225        F: Fn(<T as IntoParallelIterator>::Item) -> candle_core::Result<U> + Sync + Send,
226        U: Send,
227    {
228        // if is_parallel {
229        //     self.into_par_iter().map(f).collect()
230        // } else {
231        //     self.into_iter().map(f).collect()
232        // }
233        self.into_iter().map(f).collect()
234    }
235
236    /// Applies the given closure over the items, optionally in parallel, and collects the results.
237    ///
238    /// - `f`: A closure to apply to each item.
239    pub fn par_iter_if_isq<F, U>(self, f: F) -> candle_core::Result<Vec<U>>
240    where
241        F: Fn(<T as IntoParallelIterator>::Item) -> candle_core::Result<U> + Sync + Send,
242        U: Send,
243    {
244        self.run(get_immediate_isq().is_some_and(|x| x.ty.is_some()), f)
245    }
246}