mistralrs_core/utils/
progress.rs1use indicatif::{
2 MultiProgress, ProgressBar, ProgressBarIter, ProgressDrawTarget, ProgressIterator,
3 ProgressStyle,
4};
5use mistralrs_quant::get_immediate_isq;
6use rayon::iter::{IntoParallelIterator, ParallelIterator};
7use rayon::prelude::*;
8use std::iter::Iterator;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use tqdm::Iter;
11
/// Number of live [`ProgressScopeGuard`]s that requested suppression.
///
/// Progress output counts as suppressed whenever this is non-zero, so nested
/// silent scopes compose: output resumes only once the last of them ends.
static PROGRESS_SUPPRESS_COUNT: AtomicUsize = AtomicUsize::new(0);

/// RAII guard that suppresses progress-bar output while it is alive.
///
/// A guard created with `silent == false` does nothing, which lets callers
/// always hold a guard and simply forward their own "silent" flag.
#[derive(Debug)]
#[must_use = "progress output is only suppressed while the guard is alive; bind it to a variable"]
pub struct ProgressScopeGuard {
    /// Whether this guard incremented the suppression counter (and therefore
    /// must decrement it when dropped).
    suppressed: bool,
}

impl ProgressScopeGuard {
    /// Creates a guard.
    ///
    /// When `silent` is true, progress output stays suppressed until the
    /// returned guard is dropped.
    pub fn new(silent: bool) -> Self {
        if silent {
            PROGRESS_SUPPRESS_COUNT.fetch_add(1, Ordering::SeqCst);
        }
        Self { suppressed: silent }
    }
}

impl Drop for ProgressScopeGuard {
    fn drop(&mut self) {
        // Only undo the increment this particular guard performed.
        if self.suppressed {
            PROGRESS_SUPPRESS_COUNT.fetch_sub(1, Ordering::SeqCst);
        }
    }
}

/// Returns `true` while at least one suppressing [`ProgressScopeGuard`] is alive.
#[inline]
pub fn progress_suppressed() -> bool {
    PROGRESS_SUPPRESS_COUNT.load(Ordering::SeqCst) > 0
}
40
41#[inline]
42pub fn configure_progress_bar(bar: &ProgressBar) {
43 if progress_suppressed() {
44 bar.set_draw_target(ProgressDrawTarget::hidden());
45 }
46}
47
48pub fn new_multi_progress() -> MultiProgress {
49 let multi = MultiProgress::new();
50 if progress_suppressed() {
51 multi.set_draw_target(ProgressDrawTarget::hidden());
52 }
53 multi
54}
55
56pub trait IterWithProgress<'a, T>: Iterator<Item = T> + 'a {
60 fn with_progress(self, is_silent: bool) -> Box<dyn Iterator<Item = T> + 'a>
61 where
62 Self: Sized,
63 {
64 if is_silent {
66 Box::new(self)
67 } else {
68 Box::new(self.tqdm())
69 }
70 }
71}
72
73impl<'a, T: Iterator + 'a> IterWithProgress<'a, T::Item> for T {}
74
75pub struct NiceProgressBar<'a, T: ExactSizeIterator, const COLOR: char = 'b'>(
78 pub T,
79 pub &'static str,
80 pub &'a MultiProgress,
81);
82
83impl<T: ExactSizeIterator, const COLOR: char> IntoIterator for NiceProgressBar<'_, T, COLOR> {
84 type IntoIter = ProgressBarIter<T>;
85 type Item = T::Item;
86
87 fn into_iter(self) -> Self::IntoIter {
88 let color = match COLOR {
89 'b' => "blue",
90 'g' => "green",
91 'r' => "red",
92 other => panic!("Color char `{other}` not supported"),
93 };
94 let bar = ProgressBar::new(self.0.len() as u64);
95 configure_progress_bar(&bar);
96 bar.set_style(
97 ProgressStyle::default_bar()
98 .template(&format!(
99 "{}: [{{elapsed_precise}}] [{{bar:40.{color}/{color}}}] {{pos}}/{{len}} ({{eta}})",
100 self.1
101 ))
102 .unwrap()
103 .progress_chars("#>-"),
104 );
105
106 self.2.add(bar.clone());
108
109 self.0.progress_with(bar)
110 }
111}
112
113pub struct ParProgress<I> {
115 iter: I,
116 bar: ProgressBar,
117}
118
119impl<I> ParallelIterator for ParProgress<I>
120where
121 I: ParallelIterator,
122 I::Item: Send,
123{
124 type Item = I::Item;
125
126 fn drive_unindexed<C>(self, consumer: C) -> C::Result
127 where
128 C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
129 {
130 let bar = self.bar.clone();
131 let iter = self.iter.map(move |item| {
132 bar.inc(1);
133 item
134 });
135 iter.drive_unindexed(consumer)
136 }
137}
138
139impl<I> IndexedParallelIterator for ParProgress<I>
140where
141 I: IndexedParallelIterator,
142 I::Item: Send,
143{
144 fn len(&self) -> usize {
145 self.iter.len()
146 }
147
148 fn drive<C>(self, consumer: C) -> C::Result
149 where
150 C: rayon::iter::plumbing::Consumer<Self::Item>,
151 {
152 let bar = self.bar.clone();
153 let iter = self.iter.map(move |item| {
154 bar.inc(1);
155 item
156 });
157 iter.drive(consumer)
158 }
159
160 fn with_producer<CB>(self, callback: CB) -> CB::Output
161 where
162 CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
163 {
164 let bar = self.bar.clone();
165 let iter = self.iter.map(move |item| {
166 bar.inc(1);
167 item
168 });
169 iter.with_producer(callback)
170 }
171}
172
173impl<'a, T, const COLOR: char> IntoParallelIterator for NiceProgressBar<'a, T, COLOR>
174where
175 T: ExactSizeIterator + IntoParallelIterator + Send + Sync + 'a,
176 <T as IntoParallelIterator>::Item: Send + 'a,
177 T::Iter: ParallelIterator<Item = <T as IntoParallelIterator>::Item>
178 + IndexedParallelIterator<Item = <T as IntoParallelIterator>::Item>
179 + Send,
180{
181 type Iter = ParProgress<T::Iter>;
182 type Item = <T as IntoParallelIterator>::Item;
183
184 fn into_par_iter(self) -> Self::Iter {
185 let color = match COLOR {
186 'b' => "blue",
187 'g' => "green",
188 'r' => "red",
189 other => panic!("Color char `{other}` not supported"),
190 };
191 let bar = ProgressBar::new(self.0.len() as u64);
192 configure_progress_bar(&bar);
193 bar.set_style(
194 ProgressStyle::default_bar()
195 .template(&format!(
196 "{}: [{{elapsed_precise}}] [{{bar:40.{color}/{color}}}] {{pos}}/{{len}} ({{eta}})",
197 self.1
198 ))
199 .unwrap()
200 .progress_chars("#>-"),
201 );
202 self.2.add(bar.clone());
203 ParProgress {
204 iter: self.0.into_par_iter(),
205 bar,
206 }
207 }
208}
209
210impl<'a, T, const COLOR: char> NiceProgressBar<'a, T, COLOR>
211where
212 T: ExactSizeIterator + IntoParallelIterator + Send + Sync + 'a,
213 <T as IntoParallelIterator>::Item: Send + 'a,
214 T::Iter: ParallelIterator<Item = <T as IntoParallelIterator>::Item>
215 + IndexedParallelIterator<Item = <T as IntoParallelIterator>::Item>
216 + Send,
217 T: IntoParallelIterator<Item = <T as Iterator>::Item>,
218{
219 pub fn run<F, U>(self, _is_parallel: bool, f: F) -> candle_core::Result<Vec<U>>
224 where
225 F: Fn(<T as IntoParallelIterator>::Item) -> candle_core::Result<U> + Sync + Send,
226 U: Send,
227 {
228 self.into_iter().map(f).collect()
234 }
235
236 pub fn par_iter_if_isq<F, U>(self, f: F) -> candle_core::Result<Vec<U>>
240 where
241 F: Fn(<T as IntoParallelIterator>::Item) -> candle_core::Result<U> + Sync + Send,
242 U: Send,
243 {
244 self.run(get_immediate_isq().is_some_and(|x| x.ty.is_some()), f)
245 }
246}