use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    env,
    fs::File,
    path::PathBuf,
    str::FromStr,
    sync::{atomic::AtomicUsize, Arc},
    time::Instant,
};
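
/// Zero-copy view over a byte buffer, exposed to `safetensors` as a flat `U8`
/// tensor so already-serialized layer data can be written without copying.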
#[derive(Clone)]
pub struct CowBytesView<'a> {
    data: Cow<'a, [u8]>,
    shape: [usize; 1],
}

impl<'a> CowBytesView<'a> {
    pub fn new(data: Cow<'a, [u8]>) -> Self {
        // The "shape" is simply the total byte length.
        let len = data.len();
        Self { data, shape: [len] }
    }
}

impl<'a> safetensors::tensor::View for CowBytesView<'a> {
    fn dtype(&self) -> safetensors::tensor::Dtype {
        safetensors::tensor::Dtype::U8
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn data(&self) -> Cow<[u8]> {
        // The data must stay borrowed: cloning a `Cow::Borrowed` copies only
        // the reference, keeping serialization zero-copy.
        assert!(matches!(self.data, Cow::Borrowed(_)));
        self.data.clone()
    }

    fn data_len(&self) -> usize {
        self.data.len()
    }
}
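
// Usage sketch (hypothetical buffer and path): wrap borrowed bytes and hand
// them to `safetensors::serialize_to_file`, as the UQFF sharding code below does.
//
//     let bytes: Vec<u8> = vec![0u8; 16];
//     let view = CowBytesView::new(Cow::Borrowed(&bytes));
//     safetensors::serialize_to_file(vec![("tensor_0", view)], &None, path)?;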

use anyhow::Result;
use candle_core::{quantized, Context, Device, Tensor};
use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
use itertools::Itertools;
use mistralrs_quant::{
    AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
    HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    ReplicatedLayer, RowParallelLayer, UnquantLinear,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use regex::Regex;
use serde::Deserialize;
use tokenizers::Tokenizer;
use tracing::{info, warn};

use crate::{device_map::DeviceMapper, topology::LayerTopology, Topology};

pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
/// Maximum size of a single UQFF shard file (10 GiB).
const MAX_UQFF_SIZE_BYTES: usize = 10 * 1024 * 1024 * 1024;
pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";

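/// Parse a human-readable ISQ specifier (e.g. `"4"`, `"q4k"`, `"afq8"`) into an
/// [`IsqType`]. Bare bit widths map to AFQ types on Metal and to GGUF K-quants
/// (or `Q8_0`) elsewhere.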
pub fn parse_isq_value(s: &str, device: Option<&Device>) -> Result<IsqType, String> {
    let is_metal = device.map(|device| device.is_metal()).unwrap_or(false);
    let tp = match s.to_lowercase().as_str() {
        "2" if is_metal => IsqType::AFQ2,
        "2" if !is_metal => IsqType::Q2K,
        "3" if is_metal => IsqType::AFQ3,
        "3" if !is_metal => IsqType::Q3K,
        "4" if is_metal => IsqType::AFQ4,
        "4" if !is_metal => IsqType::Q4K,
        "5" => IsqType::Q5K,
        "6" if is_metal => IsqType::AFQ6,
        "6" if !is_metal => IsqType::Q6K,
        "8" if is_metal => IsqType::AFQ8,
        "8" if !is_metal => IsqType::Q8_0,
        "q4_0" => IsqType::Q4_0,
        "q4_1" => IsqType::Q4_1,
        "q5_0" => IsqType::Q5_0,
        "q5_1" => IsqType::Q5_1,
        "q8_0" => IsqType::Q8_0,
        "q8_1" => IsqType::Q8_1,
        "q2k" => IsqType::Q2K,
        "q3k" => IsqType::Q3K,
        "q4k" => IsqType::Q4K,
        "q5k" => IsqType::Q5K,
        "q6k" => IsqType::Q6K,
        "q8k" => IsqType::Q8K,
        "hqq8" => IsqType::HQQ8,
        "hqq4" => IsqType::HQQ4,
        "fp8" => IsqType::F8E4M3,
        "afq8" => IsqType::AFQ8,
        "afq6" => IsqType::AFQ6,
        "afq4" => IsqType::AFQ4,
        "afq3" => IsqType::AFQ3,
        "afq2" => IsqType::AFQ2,
        _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `5`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
    };
    #[cfg(feature = "cuda")]
    {
        if !matches!(
            tp,
            IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q8_0
                | IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q5K
                | IsqType::Q6K
                | IsqType::HQQ8
                | IsqType::HQQ4
                | IsqType::F8E4M3
        ) {
            return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
        }
    }
    Ok(tp)
}
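
// Example (sketch, hypothetical device): on a non-Metal device, a bare bit
// width selects the matching K-quant.
//
//     let isq = parse_isq_value("4", Some(&Device::Cpu)); // Ok(IsqType::Q4K)
//     let isq = parse_isq_value("afq4", None);            // Ok(IsqType::AFQ4)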

#[derive(Clone, Debug, Copy, Default, Deserialize)]
pub enum IsqOrganization {
    /// Apply ISQ to every eligible layer.
    #[default]
    #[serde(rename = "default")]
    Default,
    /// Apply ISQ only to the MoE expert layers ("MoQE"-style quantization).
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}

impl FromStr for IsqOrganization {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "default" => Ok(Self::Default),
            "moqe" => Ok(Self::MoeExpertsOnly),
            other => Err(format!(
                "Expected ISQ organization `default` or `moqe`, got `{other}`"
            )),
        }
    }
}
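
// Example: string parsing and serde deserialization accept the same names.
//
//     let org: IsqOrganization = "moqe".parse().unwrap(); // IsqOrganization::MoeExpertsOnly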

/// Auxiliary model files (tokenizer, configs) serialized next to a UQFF model.
pub struct UqffFullSer<'a> {
    pub tokenizer: &'a Tokenizer,
    pub template_filename: &'a Option<PathBuf>,
    pub generation_config: Option<&'a PathBuf>,
    pub config: String,
    pub processor_filename: &'a Option<PathBuf>,
    pub preprocessor_filename: &'a Option<PathBuf>,
}

/// Where the importance-matrix (imatrix) data used to weight quantization comes from.
#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    /// A `.imatrix` (GGUF-format) or `.cimatrix` (collected) file on disk.
    File(&'a PathBuf),
    /// Statistics collected in this process via `begin_track_stats`.
    Collected,
}

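/// A model whose linear layers can be quantized in-situ (ISQ), and whose
/// quantized form can be serialized to and loaded from UQFF artifacts.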
pub trait IsqModel {
    /// All quantizable layers, each paired with its layer index (if any), plus
    /// the model's device mapper.
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );

    /// Begin collecting activation statistics on every layer, for later
    /// extraction as imatrix data.
    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

    /// Finish tracking and extract the collected imatrix data, keyed by layer index.
    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

    /// Like [`Self::get_layers`], but restricted to MoE expert layers when the
    /// model distinguishes them; defaults to all layers.
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }

    /// Begin tracking activation statistics on the MoE-experts-only layer set.
    fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

    /// Extract collected imatrix data for the MoE-experts-only layer set.
    fn extract_imatrix_data_moe_experts_only(
        &mut self,
    ) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

    /// The layer names as they appear in a GGUF-format imatrix file, in
    /// `get_layers` order. Required for quantizing from a `.imatrix` file.
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        candle_core::bail!("This model does not support quantizing with an imatrix.");
    }

    /// Tensors that ISQ does not quantize (norms, embeddings, etc.), saved to
    /// the UQFF residual file.
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;

    /// Residual tensors for the MoE-experts-only organization, if they differ.
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }

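    /// Quantize the model in-situ. Layer dtype and device may be overridden
    /// per layer by `topology`; quantization error may be weighted by imatrix
    /// data; and the result may be serialized as sharded UQFF via
    /// `write_artifacts`.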
    #[allow(clippy::too_many_arguments)]
    fn quantize(
        &mut self,
        dtype: Option<IsqType>,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        imatrix_source: Option<ImatrixDataSource<'_>>,
        organization: IsqOrganization,
        write_artifacts: Option<&PathBuf>,
        full_ser: UqffFullSer<'_>,
        multi_progress: Arc<MultiProgress>,
    ) -> candle_core::Result<()> {
        {
            let imatrix_to_weight = match imatrix_source {
                Some(ImatrixDataSource::File(imatrix)) => {
                    let ext = imatrix.extension().ok_or(candle_core::Error::msg(
                        "Expected an extension for the imatrix source file.",
                    ))?;
                    if ext == "cimatrix" {
                        info!(
                            "Loading collected imatrix source file: `{}`",
                            imatrix.display()
                        );
                        Some(CollectedImatrixData::load_imatrix(imatrix)?.0)
                    } else {
                        if ext != "imatrix" {
                            warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming the GGUF imatrix format.");
                        }
                        info!(
                            "Loading GGUF-format imatrix source file: `{}`",
                            imatrix.display()
                        );
                        let mut imatrix_data =
                            quantized::imatrix_file::load_imatrix(imatrix.clone())?;
                        let imatrix_mapping = self
                            .imatrix_names()?
                            .into_iter()
                            .enumerate()
                            .collect::<HashMap<_, _>>();

                        // Match each layer index to its imatrix weights by name.
                        let layer_to_weight = imatrix_mapping
                            .into_iter()
                            .map(|(i, name)| {
                                if let Some(name) = name {
                                    (i, Some(imatrix_data.remove(&name).unwrap()))
                                } else {
                                    (i, None)
                                }
                            })
                            .collect::<HashMap<_, _>>();
                        info!(
                            "Quantizing with imatrix file `{}`, {} imatrix weights",
                            imatrix.display(),
                            layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
                        );
                        Some(layer_to_weight)
                    }
                }
                Some(ImatrixDataSource::Collected) => {
                    let data = match organization {
                        IsqOrganization::Default => self.extract_imatrix_data()?,
                        IsqOrganization::MoeExpertsOnly => {
                            self.extract_imatrix_data_moe_experts_only()?
                        }
                    };
                    // Save the collected statistics so they can be reused later
                    // as a `.cimatrix` file.
                    let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
                    let save_path = format!("collected-{count}.cimatrix");
                    info!("Saving collected imatrix data to `{save_path}`");
                    data.save_imatrix(save_path)?;
                    info!("Quantizing with collected imatrix data, {count} imatrix weights");
                    Some(data.0)
                }
                None => None,
            };

            let (mut tensors, mapper) = match organization {
                IsqOrganization::Default => self.get_layers(),
                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
            };

            // Flatten the imatrix map into `get_layers` order.
            let imatrix_to_weight: Vec<Option<Vec<f32>>> =
                if let Some(mut imatrix_to_weight) = imatrix_to_weight {
                    let ordered_keys = imatrix_to_weight
                        .keys()
                        .copied()
                        .sorted()
                        .collect::<Vec<_>>();
                    ordered_keys
                        .into_iter()
                        .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
                        .collect()
                } else {
                    vec![None; tensors.len()]
                };

            let total_tensors = tensors.len();
            let n_quantized = AtomicUsize::new(0);
            if let Some(topology) = topology {
                let mut dtypes = HashSet::new();
                for layer in topology.0.iter().flatten() {
                    if let LayerTopology {
                        isq: Some(isq_dtype),
                        device: _,
                    } = layer
                    {
                        dtypes.insert(isq_dtype);
                    }
                }
                info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
            } else {
                info!("Applying in-situ quantization into {dtype:?} to {total_tensors} tensors.");
            }
            let bar = ProgressBar::new(total_tensors as u64);
            bar.set_style(
                ProgressStyle::default_bar()
                    .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
                    .unwrap()
                    .progress_chars("#>-"),
            );
            multi_progress.add(bar.clone());

            let layers = topology.map(|x| {
                x.0.iter()
                    .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                    .collect::<Vec<_>>()
            });

            // Resolve the target device and dtype for each tensor: the topology
            // takes precedence, then the device mapper, then the default device.
            let mut devices_and_dtypes = Vec::new();
            for (_, layer_num) in &tensors {
                let device = if let Some(ref layers) = layers {
                    if let Some(layer) = layer_num {
                        layers
                            .get(*layer)
                            .as_ref()
                            .map(|x| x.1.clone())
                            .unwrap_or(Some(device.clone()))
                            .unwrap_or(device.clone())
                    } else {
                        device.clone()
                    }
                } else if let Some(layer_num) = layer_num {
                    mapper
                        .device_for(*layer_num, false)
                        .cloned()
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                };
                let dtype = if let Some(ref layers) = layers {
                    if let Some(layer) = layer_num {
                        layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
                    } else {
                        dtype
                    }
                } else {
                    dtype
                };
                devices_and_dtypes.push((device, dtype));
            }

            let t_start = Instant::now();

            use rayon::iter::IntoParallelRefIterator;

            // Respect any per-dtype cap on CPU quantization threads.
            let mut minimum_max_threads = {
                let current_rayon_threads = rayon::current_num_threads();
                if let Some(dtype) = dtype {
                    dtype
                        .get_max_isq_cpu_threads()
                        .map(usize::from)
                        .unwrap_or(current_rayon_threads)
                } else {
                    current_rayon_threads
                }
            };
            if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
                minimum_max_threads = 1;
            }

            if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
                // Quantizing with collected imatrix data is forced onto a single thread.
                minimum_max_threads = 1;
            }

            info!("Applying ISQ on {minimum_max_threads} threads.");

            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(minimum_max_threads)
                .build()
                .map_err(candle_core::Error::msg)?;

            let guard = QuantizeOntoGuard::new();

            pool.install(|| {
                use indicatif::ParallelProgressIterator;
                use rayon::iter::{
                    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
                };
                if silent {
                    tensors
                        .par_iter_mut()
                        .zip(devices_and_dtypes)
                        .zip(imatrix_to_weight)
                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                            **tensor = tensor
                                .clone()
                                .apply_isq(
                                    dtype,
                                    device.clone(),
                                    &n_quantized,
                                    imatrix_weight,
                                    guard.clone(),
                                )
                                .unwrap();
                            device.synchronize().unwrap();
                        });
                } else {
                    tensors
                        .par_iter_mut()
                        .zip(devices_and_dtypes)
                        .zip(imatrix_to_weight)
                        .progress_with(bar)
                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                            **tensor = tensor
                                .clone()
                                .apply_isq(
                                    dtype,
                                    device.clone(),
                                    &n_quantized,
                                    imatrix_weight,
                                    guard.clone(),
                                )
                                .unwrap();
                            device.synchronize().unwrap();
                        });
                }
            });

            if let Some(serialized) = write_artifacts {
                info!(
                    "Serializing {total_tensors} ISQ tensors to `{}`.",
                    serialized.display()
                );

                if serialized.extension().is_none_or(|ext| ext != "uqff") {
                    candle_core::bail!("UQFF output path extension must be `.uqff`");
                }

                let bar = ProgressBar::new(total_tensors as u64);
                bar.set_style(
                    ProgressStyle::default_bar()
                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                        .unwrap()
                        .progress_chars("#>-"),
                );

                #[cfg(not(feature = "metal"))]
                let n_threads = 2;
                #[cfg(feature = "metal")]
                let n_threads = 1;

                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(n_threads)
                    .build()
                    .map_err(candle_core::Error::msg)?;

                let quantized_values = pool.install(|| {
                    if silent {
                        tensors
                            .par_iter()
                            .enumerate()
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    match layer.serialize()? {
                                        Cow::Borrowed(_) => unreachable!(),
                                        Cow::Owned(owned) => owned,
                                    },
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    } else {
                        tensors
                            .par_iter()
                            .enumerate()
                            .progress_with(bar)
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    match layer.serialize()? {
                                        Cow::Borrowed(_) => unreachable!(),
                                        Cow::Owned(owned) => owned,
                                    },
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    }
                });
                let quantized_values = quantized_values?;

                let parent = serialized
                    .parent()
                    .context("Target UQFF path must have a filename!")?;

                std::fs::create_dir_all(parent)?;

                let file_stem = serialized
                    .file_stem()
                    .context("Target UQFF path must have a file stem!")?
                    .to_string_lossy()
                    .to_string();

                // Greedily pack serialized tensors into shards, starting a new
                // shard whenever adding the next tensor would exceed
                // MAX_UQFF_SIZE_BYTES.
                let mut current_chunk = Vec::new();
                let mut current_bytes: usize = 0;
                let mut shard_index = 0;

                for (name, tensor) in quantized_values.iter() {
                    let tensor_bytes = tensor.len();
                    if !current_chunk.is_empty()
                        && current_bytes + tensor_bytes > MAX_UQFF_SIZE_BYTES
                    {
                        let mut shard_path = parent.to_path_buf();
                        shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
                        info!(
                            "Writing shard {} to `{}`",
                            shard_index,
                            shard_path.display()
                        );
                        safetensors::serialize_to_file(current_chunk.clone(), &None, &shard_path)?;
                        shard_index += 1;
                        current_chunk.clear();
                        current_bytes = 0;
                    }
                    current_bytes += tensor_bytes;
                    current_chunk.push((name, CowBytesView::new(Cow::Borrowed(tensor))));
                }

                if !current_chunk.is_empty() {
                    let mut shard_path = parent.to_path_buf();
                    shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
                    info!(
                        "Writing final shard {} to `{}`",
                        shard_index,
                        shard_path.display()
                    );
                    safetensors::serialize_to_file(current_chunk.clone(), &None, &shard_path)?;
                }

                let residual = match organization {
                    IsqOrganization::Default => self.residual_tensors(),
                    IsqOrganization::MoeExpertsOnly => self
                        .residual_tensors_moe_experts_only()
                        .unwrap_or(self.residual_tensors()),
                };

                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
                let config_out = parent.join("config.json");
                let tokenizer_out = parent.join("tokenizer.json");
                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
                let gen_cfg_out = parent.join("generation_config.json");
                let processor_out = parent.join("processor_config.json");
                let preprocessor_out = parent.join("preprocessor_config.json");

                info!(
                    "Serializing {} residual tensors to `{}`.",
                    residual.len(),
                    residual_out.display()
                );

                safetensors::serialize_to_file(residual, &None, &residual_out)?;

                let UqffFullSer {
                    tokenizer,
                    template_filename,
                    generation_config,
                    config,
                    processor_filename,
                    preprocessor_filename,
                } = full_ser;

                info!("Serializing configuration to `{}`.", config_out.display());

                std::fs::write(config_out, config)?;

                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());

                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
                    .map_err(candle_core::Error::msg)?;

                if let Some(template_filename) = template_filename {
                    info!(
                        "Serializing tokenizer config to `{}`.",
                        tokenizer_cfg_out.display()
                    );

                    let template =
                        std::fs::read(template_filename).map_err(candle_core::Error::msg)?;
                    std::fs::write(&tokenizer_cfg_out, template)
                        .map_err(candle_core::Error::msg)?;
                }

                if let Some(generation_config) = generation_config {
                    info!(
                        "Serializing generation config to `{}`.",
                        gen_cfg_out.display()
                    );

                    let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(processor_config) = processor_filename {
                    info!(
                        "Serializing processor config to `{}`.",
                        processor_out.display()
                    );

                    let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(preprocessor_config) = preprocessor_filename {
                    info!(
                        "Serializing preprocessor config to `{}`.",
                        preprocessor_out.display()
                    );

                    let cfg =
                        std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
                }
            }
            let delta = Instant::now().duration_since(t_start).as_secs_f32();
            info!("Applied in-situ quantization into {dtype:?} to {n_quantized:?} tensors out of {total_tensors} total tensors. Took {delta:.2}s");
        }
        Ok(())
    }

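    /// Load previously serialized UQFF artifacts back into the model's layers,
    /// dispatching on each layer's distributed kind or on the quant type
    /// recorded in the artifact header.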
    fn load_from_artifacts(
        &mut self,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        artifacts: &[PathBuf],
    ) -> candle_core::Result<()> {
        let (tensors, mapper) = self.get_layers();
        let total_tensors = tensors.len();

        let layers = topology.map(|x| {
            x.0.iter()
                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                .collect::<Vec<_>>()
        });

        // Resolve the target device and communicator for each tensor.
        let mut devices = Vec::new();
        let mut comms = Vec::new();
        for (_, layer_num) in &tensors {
            let device = if let Some(ref layers) = layers {
                if let Some(layer) = layer_num {
                    layers
                        .get(*layer)
                        .as_ref()
                        .map(|x| x.1.clone())
                        .unwrap_or(Some(device.clone()))
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                }
            } else if let Some(layer_num) = layer_num {
                mapper
                    .device_for(*layer_num, false)
                    .cloned()
                    .unwrap_or(device.clone())
            } else {
                device.clone()
            };
            devices.push(device);
            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?);
        }

        // SAFETY: the artifact files are memory-mapped read-only and must not
        // be mutated externally while in use.
        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::multi(artifacts)? };

        // UQFF tensor names are the layer indices assigned at serialization time.
        let artifact_isqs = artifacts
            .tensors()
            .into_iter()
            .map(|(name, tensor)| {
                (
                    name.parse::<usize>()
                        .expect("Name should be parseable as usize"),
                    tensor,
                )
            })
            .collect::<HashMap<_, _>>();

        if artifact_isqs.len() != total_tensors {
            candle_core::bail!(
                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
                artifact_isqs.len(),
            );
        }

        let bar = ProgressBar::new(total_tensors as u64);
        bar.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                .unwrap()
                .progress_chars("#>-"),
        );

        let t_start = Instant::now();

        let guard = QuantizeOntoGuard::new();

        if silent {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        // Distributed layers know their own kind; for the rest,
                        // the quant scheme is read from the artifact header.
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        } else {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .progress_with(bar)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        }

        let delta = Instant::now().duration_since(t_start).as_secs_f32();
        info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s");

        Ok(())
    }
}

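/// Loader-side hooks selecting which layers ISQ targets, expressed as regexes
/// over tensor names.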
pub(crate) trait IsqModelLoader {
    /// Regexes matching layers to quantize immediately as they are loaded
    /// (quantize-on-load ISQ).
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// Immediate-ISQ regexes for the MoE-experts-only ("moqe") organization;
    /// defers to the moqe layer regexes.
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes_moqe(config)
    }

    /// Regexes matching the layers that ISQ should quantize.
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// Layer regexes for the MoE-experts-only ("moqe") organization; defaults
    /// to the standard layer set.
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
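
// Implementation sketch (hypothetical model): expose regexes matching the
// quantizable projection weights. The names below assume a Llama-style layer
// layout and are illustrative only.
//
//     impl IsqModelLoader for MyLlamaLoader {
//         fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
//             Ok(vec![
//                 Regex::new(r"layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.(weight|bias)$")?,
//                 Regex::new(r"layers\.(\d+)\.mlp\.(gate|up|down)_proj\.(weight|bias)$")?,
//             ])
//         }
//     }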
1050}