1use std::{
2 borrow::Cow,
3 collections::{HashMap, HashSet},
4 env,
5 fs::File,
6 path::PathBuf,
7 str::FromStr,
8 sync::{atomic::AtomicUsize, Arc},
9 time::Instant,
10};
11
12use anyhow::Result;
13use candle_core::{quantized, Context, Device, Tensor};
14use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
15use itertools::Itertools;
16use mistralrs_quant::{
17 AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
18 HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
19 ReplicatedLayer, RowParallelLayer, UnquantLinear,
20};
21use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
22use regex::Regex;
23use serde::Deserialize;
24use tokenizers::Tokenizer;
25use tracing::{info, warn};
26
27use crate::{device_map::DeviceMapper, topology::LayerTopology, Topology};
28
29pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
30
31pub fn parse_isq_value(s: &str) -> Result<IsqType, String> {
60 let tp = match s.to_lowercase().as_str() {
61 "2" if cfg!(feature = "metal") => IsqType::AFQ2,
62 "2" if !cfg!(feature = "metal") => IsqType::Q2K,
63 "3" if cfg!(feature = "metal") => IsqType::AFQ3,
64 "3" if !cfg!(feature = "metal") => IsqType::Q3K,
65 "4" if cfg!(feature = "metal") => IsqType::AFQ4,
66 "4" if !cfg!(feature = "metal") => IsqType::Q4K,
67 "5" => IsqType::Q5K,
68 "6" if cfg!(feature = "metal") => IsqType::AFQ6,
69 "6" if !cfg!(feature = "metal") => IsqType::Q6K,
70 "8" if cfg!(feature = "metal") => IsqType::AFQ8,
71 "8" if !cfg!(feature = "metal") => IsqType::Q8_0,
72 "q4_0" => IsqType::Q4_0,
73 "q4_1" => IsqType::Q4_1,
74 "q5_0" => IsqType::Q5_0,
75 "q5_1" => IsqType::Q5_1,
76 "q8_0" => IsqType::Q8_0,
77 "q8_1" => IsqType::Q8_1,
78 "q2k" => IsqType::Q2K,
79 "q3k" => IsqType::Q3K,
80 "q4k" => IsqType::Q4K,
81 "q5k" => IsqType::Q5K,
82 "q6k" => IsqType::Q6K,
83 "q8k" => IsqType::Q8K,
84 "hqq8" => IsqType::HQQ8,
85 "hqq4" => IsqType::HQQ4,
86 "fp8" => IsqType::F8E4M3,
87 "afq8" => IsqType::AFQ8,
88 "afq6" => IsqType::AFQ6,
89 "afq4" => IsqType::AFQ4,
90 "afq3" => IsqType::AFQ3,
91 "afq2" => IsqType::AFQ2,
92 _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
96 };
97 #[cfg(feature = "cuda")]
98 {
99 if !matches!(
100 tp,
101 IsqType::Q4_0
102 | IsqType::Q4_1
103 | IsqType::Q5_0
104 | IsqType::Q5_1
105 | IsqType::Q8_0
106 | IsqType::Q2K
107 | IsqType::Q3K
108 | IsqType::Q4K
109 | IsqType::Q5K
110 | IsqType::Q6K
111 | IsqType::HQQ8
112 | IsqType::HQQ4
113 | IsqType::F8E4M3 ) {
117 return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
118 }
119 }
120 Ok(tp)
121}
122
/// Selects which subset of a model's layers in-situ quantization applies to.
#[derive(Clone, Debug, Copy, Default, Deserialize)]
pub enum IsqOrganization {
    /// Quantize every eligible layer (serde name: `"default"`).
    #[default]
    #[serde(rename = "default")]
    Default,
    /// Quantize only the MoE expert layers (serde name: `"moqe"`).
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}
133
134impl FromStr for IsqOrganization {
135 type Err = String;
136 fn from_str(s: &str) -> Result<Self, Self::Err> {
137 match s {
138 "default" => Ok(Self::Default),
139 "moqe" => Ok(Self::MoeExpertsOnly),
140 other => Err(format!(
141 "Expected ISQ organization `default` or `moqe`, got `{other}`"
142 )),
143 }
144 }
145}
146
/// Everything required to fully serialize a model directory alongside the UQFF
/// quantized tensors. Paths that are `None` are skipped during serialization.
pub struct UqffFullSer<'a> {
    /// Tokenizer, written as pretty JSON to `tokenizer.json`.
    pub tokenizer: &'a Tokenizer,
    /// Chat template file; copied to `tokenizer_config.json` when present.
    pub template_filename: &'a Option<PathBuf>,
    /// Generation config file; copied to `generation_config.json` when present.
    pub generation_config: Option<&'a PathBuf>,
    /// Model config JSON contents, written verbatim to `config.json`.
    pub config: String,
    /// Processor config file; copied to `processor_config.json` when present.
    pub processor_filename: &'a Option<PathBuf>,
    /// Preprocessor config file; copied to `preprocessor_config.json` when present.
    pub preprocessor_filename: &'a Option<PathBuf>,
}
155
/// Source of importance-matrix (imatrix) data used to guide quantization.
#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    /// Read from a file: `.cimatrix` (collected format) or a GGUF-style
    /// `.imatrix` file (assumed for any other extension, with a warning).
    File(&'a PathBuf),
    /// Use statistics collected in-process via the `*_track_stats` methods.
    Collected,
}
161
162pub trait IsqModel {
    /// Return mutable references to every ISQ-quantizable layer paired with its
    /// (optional) layer index, plus the device mapper used to resolve
    /// per-layer devices.
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );
171
172 fn begin_track_stats(&mut self) -> anyhow::Result<()> {
174 let layers = self
175 .get_layers()
176 .0
177 .into_iter()
178 .map(|(layer, _)| layer)
179 .collect::<Vec<_>>();
180 for layer in layers {
181 Arc::get_mut(layer).unwrap().begin_track_stats()?;
182 }
183 Ok(())
184 }
185
186 fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
188 let layers = self
189 .get_layers()
190 .0
191 .into_iter()
192 .enumerate()
193 .map(|(i, (layer, _))| (i, layer))
194 .collect::<Vec<_>>();
195 let mut data = HashMap::new();
196 for (i, layer) in layers {
197 data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
198 }
199 Ok(CollectedImatrixData(data))
200 }
201
    /// Like [`Self::get_layers`], but restricted to the MoE-experts-only layer
    /// set (used with [`IsqOrganization::MoeExpertsOnly`]). The default
    /// implementation falls back to all layers.
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }
213
214 fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
217 let layers = self
218 .get_layers()
219 .0
220 .into_iter()
221 .map(|(layer, _)| layer)
222 .collect::<Vec<_>>();
223 for layer in layers {
224 Arc::get_mut(layer).unwrap().begin_track_stats()?;
225 }
226 Ok(())
227 }
228
229 fn extract_imatrix_data_moe_experts_only(
232 &mut self,
233 ) -> candle_core::Result<CollectedImatrixData> {
234 let layers = self
235 .get_layers()
236 .0
237 .into_iter()
238 .enumerate()
239 .map(|(i, (layer, _))| (i, layer))
240 .collect::<Vec<_>>();
241 let mut data = HashMap::new();
242 for (i, layer) in layers {
243 data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
244 }
245 Ok(CollectedImatrixData(data))
246 }
247
    /// Per-layer names used to match entries of a GGUF `.imatrix` file to ISQ
    /// layers; `None` entries receive no imatrix weights. The default
    /// implementation bails: models must opt in to imatrix quantization.
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        candle_core::bail!("This model does not support quantizing with an imatrix.");
    }
258
    /// Residual (non-ISQ) tensors to serialize alongside UQFF artifacts;
    /// written to `residual.safetensors` by [`Self::quantize`].
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;
261
    /// Residual tensors for the MoE-experts-only organization. `None` (the
    /// default) means "fall back to [`Self::residual_tensors`]".
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }
266
267 #[allow(clippy::too_many_arguments)]
272 fn quantize(
273 &mut self,
274 dtype: Option<IsqType>,
275 device: Device,
276 topology: Option<&Topology>,
277 silent: bool,
278 imatrix_source: Option<ImatrixDataSource<'_>>,
279 organization: IsqOrganization,
280 write_artifacts: Option<&PathBuf>,
281 full_ser: UqffFullSer<'_>,
282 multi_progress: Arc<MultiProgress>,
283 ) -> candle_core::Result<()> {
284 {
285 let imatrix_to_weight = match imatrix_source {
286 Some(ImatrixDataSource::File(imatrix)) => {
287 let ext = imatrix.extension().ok_or(candle_core::Error::msg(
288 "Expected an extension for the imatrix source file.",
289 ))?;
290 if ext == "cimatrix" {
291 info!(
292 "Loading collected imatrix source file: `{}`",
293 imatrix.display()
294 );
295 Some(CollectedImatrixData::load_imatrix(imatrix)?.0)
296 } else if ext == "imatrix" {
297 info!(
298 "Loading GGUF-format imatrix source file: `{}`",
299 imatrix.display()
300 );
301 let mut imatrix_data =
302 quantized::imatrix_file::load_imatrix(imatrix.clone())?;
303 let imatrix_mapping = self
304 .imatrix_names()?
305 .into_iter()
306 .enumerate()
307 .collect::<HashMap<_, _>>();
308
309 let layer_to_weight = imatrix_mapping
310 .into_iter()
311 .map(|(i, name)| {
312 if let Some(name) = name {
313 (i, Some(imatrix_data.remove(&name).unwrap()))
314 } else {
315 (i, None)
316 }
317 })
318 .collect::<HashMap<_, _>>();
319 info!(
320 "Quantizing with imatrix file `{}`, {} imatrix weights",
321 imatrix.display(),
322 layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
323 );
324 Some(layer_to_weight)
325 } else {
326 warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
327 info!(
328 "Loading GGUF-format imatrix source file: `{}`",
329 imatrix.display()
330 );
331
332 let mut imatrix_data =
333 quantized::imatrix_file::load_imatrix(imatrix.clone())?;
334 let imatrix_mapping = self
335 .imatrix_names()?
336 .into_iter()
337 .enumerate()
338 .collect::<HashMap<_, _>>();
339
340 let layer_to_weight = imatrix_mapping
341 .into_iter()
342 .map(|(i, name)| {
343 if let Some(name) = name {
344 (i, Some(imatrix_data.remove(&name).unwrap()))
345 } else {
346 (i, None)
347 }
348 })
349 .collect::<HashMap<_, _>>();
350 info!(
351 "Quantizing with imatrix file `{}`, {} imatrix weights",
352 imatrix.display(),
353 layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
354 );
355 Some(layer_to_weight)
356 }
357 }
358 Some(ImatrixDataSource::Collected) => {
359 let data = match organization {
360 IsqOrganization::Default => self.extract_imatrix_data()?,
361 IsqOrganization::MoeExpertsOnly => {
362 self.extract_imatrix_data_moe_experts_only()?
363 }
364 };
365 let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
367 let save_path = format!("collected-{count}.cimatrix");
368 info!("Saving collected imatrix data to `{save_path}`");
369 data.save_imatrix(save_path)?;
370 info!("Quantizing with collected imatrix data, {count} imatrix weights");
371 Some(data.0)
372 }
373 None => {
374 None
376 }
377 };
378
379 let (mut tensors, mapper) = match organization {
380 IsqOrganization::Default => self.get_layers(),
381 IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
382 };
383
384 let imatrix_to_weight: Vec<Option<Vec<f32>>> =
385 if let Some(mut imatrix_to_weight) = imatrix_to_weight {
386 let ordered_keys = imatrix_to_weight
387 .keys()
388 .copied()
389 .sorted()
390 .collect::<Vec<_>>();
391 ordered_keys
392 .into_iter()
393 .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
394 .collect()
395 } else {
396 vec![None; tensors.len()]
397 };
398
399 let total_tensors = tensors.len();
400 let n_quantized = AtomicUsize::new(0);
401 if let Some(topology) = topology {
402 let mut dtypes = HashSet::new();
403 for layer in topology.0.iter().flatten() {
404 if let LayerTopology {
405 isq: Some(isq_dtype),
406 device: _,
407 } = layer
408 {
409 dtypes.insert(isq_dtype);
410 }
411 }
412 info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
413 } else {
414 info!("Applying in-situ quantization into {dtype:?} to {total_tensors} tensors.");
415 }
416 let bar = ProgressBar::new(total_tensors as u64);
417 bar.set_style(
418 ProgressStyle::default_bar()
419 .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
420 .unwrap()
421 .progress_chars("#>-"),
422 );
423 multi_progress.add(bar.clone());
424
425 let layers = topology.map(|x| {
426 x.0.iter()
427 .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
428 .collect::<Vec<_>>()
429 });
430
431 let mut devices_and_dtypes = Vec::new();
432 for (_, layer_num) in &tensors {
433 let device = if let Some(ref layers) = layers {
434 if let Some(layer) = layer_num {
435 layers
436 .get(*layer)
437 .as_ref()
438 .map(|x| x.1.clone())
439 .unwrap_or(Some(device.clone()))
440 .unwrap_or(device.clone())
441 } else {
442 device.clone()
443 }
444 } else if let Some(layer_num) = layer_num {
445 mapper
446 .device_for(*layer_num, false)
447 .cloned()
448 .unwrap_or(device.clone())
449 } else {
450 device.clone()
451 };
452 let dtype = if let Some(ref layers) = layers {
453 if let Some(layer) = layer_num {
454 layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
455 } else {
456 dtype
457 }
458 } else {
459 dtype
460 };
461 devices_and_dtypes.push((device, dtype));
462 }
463
464 let t_start = Instant::now();
465
466 use rayon::iter::IntoParallelRefIterator;
467
468 let mut minimum_max_threads = {
470 let current_rayon_threads = rayon::current_num_threads();
471 if let Some(dtype) = dtype {
472 dtype
473 .get_max_isq_cpu_threads()
474 .map(usize::from)
475 .unwrap_or(current_rayon_threads)
476 } else {
477 current_rayon_threads
478 }
479 };
480 if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
481 minimum_max_threads = 1;
482 }
483
484 if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
485 minimum_max_threads = 1;
487 }
488
489 info!("Applying ISQ on {minimum_max_threads} threads.");
490
491 let pool = rayon::ThreadPoolBuilder::new()
492 .num_threads(minimum_max_threads)
493 .build()
494 .map_err(candle_core::Error::msg)?;
495
496 let guard = QuantizeOntoGuard::new();
497
498 pool.install(|| {
499 use indicatif::ParallelProgressIterator;
500 use rayon::iter::{
501 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
502 };
503 if silent {
504 tensors
505 .par_iter_mut()
506 .zip(devices_and_dtypes)
507 .zip(imatrix_to_weight)
508 .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
509 **tensor = tensor
510 .clone()
511 .apply_isq(
512 dtype,
513 device.clone(),
514 &n_quantized,
515 imatrix_weight,
516 guard.clone(),
517 )
518 .unwrap();
519 device.synchronize().unwrap();
520 });
521 } else {
522 tensors
523 .par_iter_mut()
524 .zip(devices_and_dtypes)
525 .zip(imatrix_to_weight)
526 .progress_with(bar)
527 .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
528 **tensor = tensor
529 .clone()
530 .apply_isq(
531 dtype,
532 device.clone(),
533 &n_quantized,
534 imatrix_weight,
535 guard.clone(),
536 )
537 .unwrap();
538 device.synchronize().unwrap();
539 });
540 }
541 });
542
543 if let Some(serialized) = write_artifacts {
544 info!(
545 "Serializing {total_tensors} ISQ tensors to `{}`.",
546 serialized.display()
547 );
548
549 if serialized.extension().is_none_or(|ext| ext != "uqff") {
550 candle_core::bail!("UQFF output path extension must be `.uqff`",);
551 }
552
553 let bar = ProgressBar::new(total_tensors as u64);
554 bar.set_style(
555 ProgressStyle::default_bar()
556 .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
557 .unwrap()
558 .progress_chars("#>-"),
559 );
560
561 #[cfg(not(feature = "metal"))]
562 let n_threads = 2;
563 #[cfg(feature = "metal")]
564 let n_threads = 1;
565
566 let pool = rayon::ThreadPoolBuilder::new()
567 .num_threads(n_threads)
568 .build()
569 .map_err(candle_core::Error::msg)?;
570
571 let quantized_values = pool.install(|| {
572 if silent {
573 tensors
574 .par_iter()
575 .enumerate()
576 .filter(|(_, (layer, _))| layer.isq_serde_supported())
577 .map(|(i, (layer, _))| {
578 Ok((
579 i.to_string(),
580 Tensor::new(Cow::into_owned(layer.serialize()?), &Device::Cpu)?,
581 ))
582 })
583 .collect::<candle_core::Result<Vec<_>>>()
584 } else {
585 tensors
586 .par_iter()
587 .enumerate()
588 .progress_with(bar)
589 .filter(|(_, (layer, _))| layer.isq_serde_supported())
590 .map(|(i, (layer, _))| {
591 Ok((
592 i.to_string(),
593 Tensor::new(Cow::into_owned(layer.serialize()?), &Device::Cpu)?,
594 ))
595 })
596 .collect::<candle_core::Result<Vec<_>>>()
597 }
598 });
599
600 let parent = serialized
601 .parent()
602 .context("Target UQFF path must have a filename!")?;
603
604 std::fs::create_dir_all(parent)?;
605
606 safetensors::serialize_to_file(quantized_values?, &None, serialized)?;
607
608 let residual = match organization {
609 IsqOrganization::Default => self.residual_tensors(),
610 IsqOrganization::MoeExpertsOnly => self
611 .residual_tensors_moe_experts_only()
612 .unwrap_or(self.residual_tensors()),
613 };
614
615 let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
616 let config_out = parent.join("config.json");
617 let tokenizer_out = parent.join("tokenizer.json");
618 let tokenizer_cfg_out = parent.join("tokenizer_config.json");
619 let gen_cfg_out = parent.join("generation_config.json");
620 let processor_out = parent.join("processor_config.json");
621 let preprocessor_out = parent.join("preprocessor_config.json");
622
623 info!(
624 "Serializing {} residual tensors to `{}`.",
625 residual.len(),
626 residual_out.display()
627 );
628
629 safetensors::serialize_to_file(residual, &None, &residual_out)?;
630
631 let UqffFullSer {
632 tokenizer,
633 template_filename,
634 generation_config,
635 config,
636 processor_filename,
637 preprocessor_filename,
638 } = full_ser;
639
640 info!("Serializing configuration to `{}`.", config_out.display());
641
642 std::fs::write(config_out, config)?;
643
644 info!("Serializing tokenizer to `{}`.", tokenizer_out.display());
645
646 serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
647 .map_err(candle_core::Error::msg)?;
648
649 if let Some(template_filename) = template_filename {
650 info!(
651 "Serializing tokenizer config to `{}`.",
652 tokenizer_cfg_out.display()
653 );
654
655 let template =
656 std::fs::read(template_filename).map_err(candle_core::Error::msg)?;
657 std::fs::write(&tokenizer_cfg_out, template)
658 .map_err(candle_core::Error::msg)?;
659 }
660
661 if let Some(generation_config) = generation_config {
662 info!(
663 "Serializing generation config to `{}`.",
664 gen_cfg_out.display()
665 );
666
667 let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
668 std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
669 }
670
671 if let Some(processor_config) = processor_filename {
672 info!(
673 "Serializing processor config to `{}`.",
674 processor_out.display()
675 );
676
677 let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
678 std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
679 }
680
681 if let Some(preprocessor_config) = preprocessor_filename {
682 info!(
683 "Serializing preprocessor config to `{}`.",
684 preprocessor_out.display()
685 );
686
687 let cfg =
688 std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
689 std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
690 }
691 }
692 let delta = Instant::now().duration_since(t_start).as_secs_f32();
693 info!("Applied in-situ quantization into {dtype:?} to {n_quantized:?} tensors out of {total_tensors} total tensors. Took {delta:.2}s", );
694 }
695 Ok(())
696 }
697
698 fn load_from_artifacts(
699 &mut self,
700 device: Device,
701 topology: Option<&Topology>,
702 silent: bool,
703 artifacts: &PathBuf,
704 ) -> candle_core::Result<()> {
705 let (tensors, mapper) = self.get_layers();
706 let total_tensors = tensors.len();
707
708 let layers = topology.map(|x| {
709 x.0.iter()
710 .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
711 .collect::<Vec<_>>()
712 });
713
714 let mut devices = Vec::new();
715 let mut comms = Vec::new();
716 for (_, layer_num) in &tensors {
717 let device = if let Some(ref layers) = layers {
718 if let Some(layer) = layer_num {
719 layers
720 .get(*layer)
721 .as_ref()
722 .map(|x| x.1.clone())
723 .unwrap_or(Some(device.clone()))
724 .unwrap_or(device.clone())
725 } else {
726 device.clone()
727 }
728 } else if let Some(layer_num) = layer_num {
729 mapper
730 .device_for(*layer_num, false)
731 .cloned()
732 .unwrap_or(device.clone())
733 } else {
734 device.clone()
735 };
736 devices.push(device);
737 comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
738 }
739
740 let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::new(artifacts)? };
741
742 let artifact_isqs = artifacts
743 .tensors()
744 .into_iter()
745 .map(|(name, tensor)| {
746 (
747 name.parse::<usize>()
748 .expect("Name should be parseable as usize"),
749 tensor,
750 )
751 })
752 .collect::<HashMap<_, _>>();
753
754 if artifact_isqs.len() != total_tensors {
755 candle_core::bail!(
756 "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
757 artifact_isqs.len(),
758 );
759 }
760
761 let bar = ProgressBar::new(total_tensors as u64);
762 bar.set_style(
763 ProgressStyle::default_bar()
764 .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
765 .unwrap()
766 .progress_chars("#>-"),
767 );
768
769 let t_start = Instant::now();
770
771 let guard = QuantizeOntoGuard::new();
772
773 if silent {
774 (0..tensors.len())
775 .into_par_iter()
776 .zip(tensors)
777 .map(|(i, (tensor, _))| {
778 if let Some(artifact) = artifact_isqs.get(&i) {
779 let artifact = artifact.data();
780
781 let comm = comms[i].clone();
782 let deserialized = match tensor.is_distributed() {
783 Some(DistributedKind::ColumnParallel) => {
784 ColumnParallelLayer::deserialize(
785 Cow::from(artifact),
786 &devices[i],
787 &comm,
788 guard.clone(),
789 )?
790 }
791 Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
792 Cow::from(artifact),
793 &devices[i],
794 &comm,
795 guard.clone(),
796 )?,
797 Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
798 Cow::from(artifact),
799 &devices[i],
800 &comm,
801 guard.clone(),
802 )?,
803 None => {
804 let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
806 match QuantizedSerdeType::try_from(isq_type as usize)? {
807 QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
808 Cow::from(artifact),
809 &devices[i],
810 &comm,
811 guard.clone(),
812 )?,
813 QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
814 Cow::from(artifact),
815 &devices[i],
816 &comm,
817 guard.clone(),
818 )?,
819 QuantizedSerdeType::Hqq => HqqLayer::deserialize(
820 Cow::from(artifact),
821 &devices[i],
822 &comm,
823 guard.clone(),
824 )?,
825 QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
826 Cow::from(artifact),
827 &devices[i],
828 &comm,
829 guard.clone(),
830 )?,
831 QuantizedSerdeType::Afq => AfqLayer::deserialize(
832 Cow::from(artifact),
833 &devices[i],
834 &comm,
835 guard.clone(),
836 )?,
837 }
838 }
839 };
840 *tensor = deserialized;
841 }
842 Ok(())
843 })
844 .collect::<candle_core::Result<Vec<_>>>()?;
845 } else {
846 (0..tensors.len())
847 .into_par_iter()
848 .zip(tensors)
849 .progress_with(bar)
850 .map(|(i, (tensor, _))| {
851 if let Some(artifact) = artifact_isqs.get(&i) {
852 let artifact = artifact.data();
853
854 let comm = comms[i].clone();
855 let deserialized = match tensor.is_distributed() {
856 Some(DistributedKind::ColumnParallel) => {
857 ColumnParallelLayer::deserialize(
858 Cow::from(artifact),
859 &devices[i],
860 &comm,
861 guard.clone(),
862 )?
863 }
864 Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
865 Cow::from(artifact),
866 &devices[i],
867 &comm,
868 guard.clone(),
869 )?,
870 Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
871 Cow::from(artifact),
872 &devices[i],
873 &comm,
874 guard.clone(),
875 )?,
876 None => {
877 let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
879 match QuantizedSerdeType::try_from(isq_type as usize)? {
880 QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
881 Cow::from(artifact),
882 &devices[i],
883 &comm,
884 guard.clone(),
885 )?,
886 QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
887 Cow::from(artifact),
888 &devices[i],
889 &comm,
890 guard.clone(),
891 )?,
892 QuantizedSerdeType::Hqq => HqqLayer::deserialize(
893 Cow::from(artifact),
894 &devices[i],
895 &comm,
896 guard.clone(),
897 )?,
898 QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
899 Cow::from(artifact),
900 &devices[i],
901 &comm,
902 guard.clone(),
903 )?,
904 QuantizedSerdeType::Afq => AfqLayer::deserialize(
905 Cow::from(artifact),
906 &devices[i],
907 &comm,
908 guard.clone(),
909 )?,
910 }
911 }
912 };
913 *tensor = deserialized;
914 }
915 Ok(())
916 })
917 .collect::<candle_core::Result<Vec<_>>>()?;
918 }
919
920 let delta = Instant::now().duration_since(t_start).as_secs_f32();
921 info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s", );
922
923 Ok(())
924 }
925}
926
/// Per-model-loader hook declaring which tensor names ISQ may quantize.
pub(crate) trait IsqModelLoader {
    /// Regexes matching the names of quantizable layers, given the model's
    /// JSON `config` string. Defaults to empty (no ISQ-eligible layers).
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// Variant for the MoE-experts-only organization; defaults to the same
    /// regex set as [`Self::isq_layer_regexes`].
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}