mistralrs_core/utils/
progress.rs1use indicatif::{MultiProgress, ProgressBar, ProgressBarIter, ProgressIterator, ProgressStyle};
2use mistralrs_quant::get_immediate_isq;
3use rayon::iter::{IntoParallelIterator, ParallelIterator};
4use rayon::prelude::*;
5use std::iter::Iterator;
6use tqdm::Iter;
7
8pub trait IterWithProgress<'a, T>: Iterator<Item = T> + 'a {
12 fn with_progress(self, is_silent: bool) -> Box<dyn Iterator<Item = T> + 'a>
13 where
14 Self: Sized,
15 {
16 if is_silent {
18 Box::new(self)
19 } else {
20 Box::new(self.tqdm())
21 }
22 }
23}
24
25impl<'a, T: Iterator + 'a> IterWithProgress<'a, T::Item> for T {}
26
27pub struct NiceProgressBar<'a, T: ExactSizeIterator, const COLOR: char = 'b'>(
30 pub T,
31 pub &'static str,
32 pub &'a MultiProgress,
33);
34
35impl<T: ExactSizeIterator, const COLOR: char> IntoIterator for NiceProgressBar<'_, T, COLOR> {
36 type IntoIter = ProgressBarIter<T>;
37 type Item = T::Item;
38
39 fn into_iter(self) -> Self::IntoIter {
40 let color = match COLOR {
41 'b' => "blue",
42 'g' => "green",
43 'r' => "red",
44 other => panic!("Color char `{other}` not supported"),
45 };
46 let bar = ProgressBar::new(self.0.len() as u64);
47 bar.set_style(
48 ProgressStyle::default_bar()
49 .template(&format!(
50 "{}: [{{elapsed_precise}}] [{{bar:40.{color}/{color}}}] {{pos}}/{{len}} ({{eta}})",
51 self.1
52 ))
53 .unwrap()
54 .progress_chars("#>-"),
55 );
56
57 self.2.add(bar.clone());
59
60 self.0.progress_with(bar)
61 }
62}
63
64pub struct ParProgress<I> {
66 iter: I,
67 bar: ProgressBar,
68}
69
70impl<I> ParallelIterator for ParProgress<I>
71where
72 I: ParallelIterator,
73 I::Item: Send,
74{
75 type Item = I::Item;
76
77 fn drive_unindexed<C>(self, consumer: C) -> C::Result
78 where
79 C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
80 {
81 let bar = self.bar.clone();
82 let iter = self.iter.map(move |item| {
83 bar.inc(1);
84 item
85 });
86 iter.drive_unindexed(consumer)
87 }
88}
89
90impl<I> IndexedParallelIterator for ParProgress<I>
91where
92 I: IndexedParallelIterator,
93 I::Item: Send,
94{
95 fn len(&self) -> usize {
96 self.iter.len()
97 }
98
99 fn drive<C>(self, consumer: C) -> C::Result
100 where
101 C: rayon::iter::plumbing::Consumer<Self::Item>,
102 {
103 let bar = self.bar.clone();
104 let iter = self.iter.map(move |item| {
105 bar.inc(1);
106 item
107 });
108 iter.drive(consumer)
109 }
110
111 fn with_producer<CB>(self, callback: CB) -> CB::Output
112 where
113 CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
114 {
115 let bar = self.bar.clone();
116 let iter = self.iter.map(move |item| {
117 bar.inc(1);
118 item
119 });
120 iter.with_producer(callback)
121 }
122}
123
124impl<'a, T, const COLOR: char> IntoParallelIterator for NiceProgressBar<'a, T, COLOR>
125where
126 T: ExactSizeIterator + IntoParallelIterator + Send + Sync + 'a,
127 <T as IntoParallelIterator>::Item: Send + 'a,
128 T::Iter: ParallelIterator<Item = <T as IntoParallelIterator>::Item>
129 + IndexedParallelIterator<Item = <T as IntoParallelIterator>::Item>
130 + Send,
131{
132 type Iter = ParProgress<T::Iter>;
133 type Item = <T as IntoParallelIterator>::Item;
134
135 fn into_par_iter(self) -> Self::Iter {
136 let color = match COLOR {
137 'b' => "blue",
138 'g' => "green",
139 'r' => "red",
140 other => panic!("Color char `{other}` not supported"),
141 };
142 let bar = ProgressBar::new(self.0.len() as u64);
143 bar.set_style(
144 ProgressStyle::default_bar()
145 .template(&format!(
146 "{}: [{{elapsed_precise}}] [{{bar:40.{color}/{color}}}] {{pos}}/{{len}} ({{eta}})",
147 self.1
148 ))
149 .unwrap()
150 .progress_chars("#>-"),
151 );
152 self.2.add(bar.clone());
153 ParProgress {
154 iter: self.0.into_par_iter(),
155 bar,
156 }
157 }
158}
159
160impl<'a, T, const COLOR: char> NiceProgressBar<'a, T, COLOR>
161where
162 T: ExactSizeIterator + IntoParallelIterator + Send + Sync + 'a,
163 <T as IntoParallelIterator>::Item: Send + 'a,
164 T::Iter: ParallelIterator<Item = <T as IntoParallelIterator>::Item>
165 + IndexedParallelIterator<Item = <T as IntoParallelIterator>::Item>
166 + Send,
167 T: IntoParallelIterator<Item = <T as Iterator>::Item>,
168{
169 pub fn run<F, U>(self, _is_parallel: bool, f: F) -> candle_core::Result<Vec<U>>
174 where
175 F: Fn(<T as IntoParallelIterator>::Item) -> candle_core::Result<U> + Sync + Send,
176 U: Send,
177 {
178 self.into_iter().map(f).collect()
184 }
185
186 pub fn par_iter_if_isq<F, U>(self, f: F) -> candle_core::Result<Vec<U>>
190 where
191 F: Fn(<T as IntoParallelIterator>::Item) -> candle_core::Result<U> + Sync + Send,
192 U: Send,
193 {
194 self.run(get_immediate_isq().is_some_and(|x| x.ty.is_some()), f)
195 }
196}