use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    env,
    fs::File,
    path::PathBuf,
    str::FromStr,
    sync::{atomic::AtomicUsize, Arc},
    time::Instant,
};

use anyhow::Result;
use candle_core::{quantized, Context, Device, Tensor};
use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
use itertools::Itertools;
use mistralrs_quant::{
    AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
    HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    ReplicatedLayer, RowParallelLayer, UnquantLinear,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use regex::Regex;
use serde::Deserialize;
use tokenizers::Tokenizer;
use tracing::{info, warn};

use crate::{device_map::DeviceMapper, topology::LayerTopology, Topology};

pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
// Cap each serialized UQFF shard at 10 GiB.
const MAX_UQFF_SIZE_BYTES: usize = 10 * 1024 * 1024 * 1024;
// Delimiter for passing multiple UQFF files as a single argument.
pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";

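/// Parse an ISQ type from a string such as `"4"`, `"q4k"`, or `"afq8"`
/// (case-insensitive). Bare numbers pick a backend-appropriate type: AFQ
/// variants when the `metal` feature is enabled, K-quants (or `Q8_0`)
/// otherwise. A usage sketch (the error handling is illustrative, not
/// prescribed by this crate):
///
/// ```ignore
/// let tp = parse_isq_value("q4k").map_err(anyhow::Error::msg)?;
/// assert!(matches!(tp, IsqType::Q4K));
/// ```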
pub fn parse_isq_value(s: &str) -> Result<IsqType, String> {
    let tp = match s.to_lowercase().as_str() {
        "2" if cfg!(feature = "metal") => IsqType::AFQ2,
        "2" if !cfg!(feature = "metal") => IsqType::Q2K,
        "3" if cfg!(feature = "metal") => IsqType::AFQ3,
        "3" if !cfg!(feature = "metal") => IsqType::Q3K,
        "4" if cfg!(feature = "metal") => IsqType::AFQ4,
        "4" if !cfg!(feature = "metal") => IsqType::Q4K,
        "5" => IsqType::Q5K,
        "6" if cfg!(feature = "metal") => IsqType::AFQ6,
        "6" if !cfg!(feature = "metal") => IsqType::Q6K,
        "8" if cfg!(feature = "metal") => IsqType::AFQ8,
        "8" if !cfg!(feature = "metal") => IsqType::Q8_0,
        "q4_0" => IsqType::Q4_0,
        "q4_1" => IsqType::Q4_1,
        "q5_0" => IsqType::Q5_0,
        "q5_1" => IsqType::Q5_1,
        "q8_0" => IsqType::Q8_0,
        "q8_1" => IsqType::Q8_1,
        "q2k" => IsqType::Q2K,
        "q3k" => IsqType::Q3K,
        "q4k" => IsqType::Q4K,
        "q5k" => IsqType::Q5K,
        "q6k" => IsqType::Q6K,
        "q8k" => IsqType::Q8K,
        "hqq8" => IsqType::HQQ8,
        "hqq4" => IsqType::HQQ4,
        "fp8" => IsqType::F8E4M3,
        "afq8" => IsqType::AFQ8,
        "afq6" => IsqType::AFQ6,
        "afq4" => IsqType::AFQ4,
        "afq3" => IsqType::AFQ3,
        "afq2" => IsqType::AFQ2,
        _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `5`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
    };
    #[cfg(feature = "cuda")]
    {
        if !matches!(
            tp,
            IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q8_0
                | IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q5K
                | IsqType::Q6K
                | IsqType::HQQ8
                | IsqType::HQQ4
                | IsqType::F8E4M3
        ) {
            return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
        }
    }
    Ok(tp)
}

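/// How ISQ selects layers to quantize: `Default` quantizes every layer the
/// model reports, while `MoeExpertsOnly` ("MoQE") restricts quantization to
/// mixture-of-experts expert layers. A parsing sketch (illustrative):
///
/// ```ignore
/// let org: IsqOrganization = "moqe".parse().map_err(anyhow::Error::msg)?;
/// assert!(matches!(org, IsqOrganization::MoeExpertsOnly));
/// ```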
#[derive(Clone, Debug, Copy, Default, Deserialize)]
pub enum IsqOrganization {
    #[default]
    #[serde(rename = "default")]
    Default,
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}

impl FromStr for IsqOrganization {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "default" => Ok(Self::Default),
            "moqe" => Ok(Self::MoeExpertsOnly),
            other => Err(format!(
                "Expected ISQ organization `default` or `moqe`, got `{other}`"
            )),
        }
    }
}

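/// The full set of side files serialized next to UQFF artifacts: tokenizer,
/// tokenizer config (chat template), and the model's config, generation
/// config, processor, and preprocessor files.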
pub struct UqffFullSer<'a> {
    pub tokenizer: &'a Tokenizer,
    pub template_filename: &'a Option<PathBuf>,
    pub generation_config: Option<&'a PathBuf>,
    pub config: String,
    pub processor_filename: &'a Option<PathBuf>,
    pub preprocessor_filename: &'a Option<PathBuf>,
}

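/// Where imatrix data comes from: a `.imatrix` (GGUF-format) or `.cimatrix`
/// (collected) file on disk, or statistics collected in-process after
/// [`IsqModel::begin_track_stats`].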
#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    File(&'a PathBuf),
    Collected,
}

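/// A model whose layers can be quantized in-situ (ISQ), have imatrix
/// statistics collected, and be serialized to / restored from UQFF artifacts.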
pub trait IsqModel {
    /// Every ISQ-eligible layer, paired with the index of the model layer it
    /// belongs to (if any), plus the device mapper.
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );

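    /// Begin tracking activation statistics on every ISQ layer, in
    /// preparation for collecting imatrix data. A typical calibration flow
    /// (sketch; the forward passes are up to the caller):
    ///
    /// ```ignore
    /// model.begin_track_stats()?;
    /// // ... run calibration prompts through the model ...
    /// let imatrix = model.extract_imatrix_data()?;
    /// ```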
    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

189 fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
191 let layers = self
192 .get_layers()
193 .0
194 .into_iter()
195 .enumerate()
196 .map(|(i, (layer, _))| (i, layer))
197 .collect::<Vec<_>>();
198 let mut data = HashMap::new();
199 for (i, layer) in layers {
200 data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
201 }
202 Ok(CollectedImatrixData(data))
203 }
204
    /// Like [`Self::get_layers`], but returning only MoE expert layers for
    /// the MoQE organization. Defaults to all ISQ layers.
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }

    fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

    fn extract_imatrix_data_moe_experts_only(
        &mut self,
    ) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

    /// The imatrix entry name corresponding to each ISQ layer, used to match
    /// GGUF-format imatrix files onto this model's layers.
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        candle_core::bail!("This model does not support quantizing with an imatrix.");
    }

    /// Tensors that ISQ does not quantize but that are required to
    /// reconstruct the model alongside the UQFF artifacts.
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;

    /// Residual tensors for the MoQE organization; `None` falls back to
    /// [`Self::residual_tensors`].
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }

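    /// Quantize the model in-situ: each eligible layer is converted to
    /// `dtype` (or its per-layer topology override) on its mapped device,
    /// optionally weighted by imatrix data. With `write_artifacts`, the
    /// result is also serialized as UQFF. A call sketch (argument values are
    /// illustrative):
    ///
    /// ```ignore
    /// model.quantize(
    ///     Some(IsqType::Q4K),
    ///     Device::Cpu,
    ///     None,                      // topology
    ///     false,                     // silent
    ///     None,                      // imatrix source
    ///     IsqOrganization::Default,
    ///     None,                      // write_artifacts
    ///     full_ser,                  // UqffFullSer assembled by the caller
    ///     multi_progress,
    /// )?;
    /// ```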
    #[allow(clippy::too_many_arguments)]
    fn quantize(
        &mut self,
        dtype: Option<IsqType>,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        imatrix_source: Option<ImatrixDataSource<'_>>,
        organization: IsqOrganization,
        write_artifacts: Option<&PathBuf>,
        full_ser: UqffFullSer<'_>,
        multi_progress: Arc<MultiProgress>,
    ) -> candle_core::Result<()> {
        {
            let imatrix_to_weight = match imatrix_source {
                Some(ImatrixDataSource::File(imatrix)) => {
                    let ext = imatrix.extension().ok_or(candle_core::Error::msg(
                        "Expected an extension for the imatrix source file.",
                    ))?;
                    if ext == "cimatrix" {
                        info!(
                            "Loading collected imatrix source file: `{}`",
                            imatrix.display()
                        );
                        Some(CollectedImatrixData::load_imatrix(imatrix)?.0)
                    } else {
                        // Any other extension is treated as GGUF-format, with a
                        // warning if it is not the expected `.imatrix`.
                        if ext != "imatrix" {
                            warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
                        }
                        info!(
                            "Loading GGUF-format imatrix source file: `{}`",
                            imatrix.display()
                        );
                        let mut imatrix_data =
                            quantized::imatrix_file::load_imatrix(imatrix.clone())?;
                        let imatrix_mapping = self
                            .imatrix_names()?
                            .into_iter()
                            .enumerate()
                            .collect::<HashMap<_, _>>();

                        let layer_to_weight = imatrix_mapping
                            .into_iter()
                            .map(|(i, name)| {
                                if let Some(name) = name {
                                    (i, Some(imatrix_data.remove(&name).unwrap()))
                                } else {
                                    (i, None)
                                }
                            })
                            .collect::<HashMap<_, _>>();
                        info!(
                            "Quantizing with imatrix file `{}`, {} imatrix weights",
                            imatrix.display(),
                            layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
                        );
                        Some(layer_to_weight)
                    }
                }
                Some(ImatrixDataSource::Collected) => {
                    let data = match organization {
                        IsqOrganization::Default => self.extract_imatrix_data()?,
                        IsqOrganization::MoeExpertsOnly => {
                            self.extract_imatrix_data_moe_experts_only()?
                        }
                    };
                    let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
                    // Save the collected data so it can be reused without recollecting.
                    let save_path = format!("collected-{count}.cimatrix");
                    info!("Saving collected imatrix data to `{save_path}`");
                    data.save_imatrix(save_path)?;
                    info!("Quantizing with collected imatrix data, {count} imatrix weights");
                    Some(data.0)
                }
                None => None,
            };

            let (mut tensors, mapper) = match organization {
                IsqOrganization::Default => self.get_layers(),
                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
            };

            // Flatten the per-layer imatrix map into the same order as `tensors`.
            let imatrix_to_weight: Vec<Option<Vec<f32>>> =
                if let Some(mut imatrix_to_weight) = imatrix_to_weight {
                    let ordered_keys = imatrix_to_weight
                        .keys()
                        .copied()
                        .sorted()
                        .collect::<Vec<_>>();
                    ordered_keys
                        .into_iter()
                        .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
                        .collect()
                } else {
                    vec![None; tensors.len()]
                };

            let total_tensors = tensors.len();
            let n_quantized = AtomicUsize::new(0);
            if let Some(topology) = topology {
                let mut dtypes = HashSet::new();
                for layer in topology.0.iter().flatten() {
                    if let LayerTopology {
                        isq: Some(isq_dtype),
                        device: _,
                    } = layer
                    {
                        dtypes.insert(isq_dtype);
                    }
                }
                info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
            } else {
                info!("Applying in-situ quantization into {dtype:?} to {total_tensors} tensors.");
            }
            let bar = ProgressBar::new(total_tensors as u64);
            bar.set_style(
                ProgressStyle::default_bar()
                    .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
                    .unwrap()
                    .progress_chars("#>-"),
            );
            multi_progress.add(bar.clone());

            let layers = topology.map(|x| {
                x.0.iter()
                    .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                    .collect::<Vec<_>>()
            });

            // Resolve the target device and dtype for each tensor, preferring
            // per-layer topology overrides, then the device mapper, then the
            // defaults passed in by the caller.
            let mut devices_and_dtypes = Vec::new();
            for (_, layer_num) in &tensors {
                let device = if let Some(ref layers) = layers {
                    if let Some(layer) = layer_num {
                        layers
                            .get(*layer)
                            .as_ref()
                            .map(|x| x.1.clone())
                            .unwrap_or(Some(device.clone()))
                            .unwrap_or(device.clone())
                    } else {
                        device.clone()
                    }
                } else if let Some(layer_num) = layer_num {
                    mapper
                        .device_for(*layer_num, false)
                        .cloned()
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                };
                let dtype = if let Some(ref layers) = layers {
                    if let Some(layer) = layer_num {
                        layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
                    } else {
                        dtype
                    }
                } else {
                    dtype
                };
                devices_and_dtypes.push((device, dtype));
            }

            let t_start = Instant::now();

            use rayon::iter::IntoParallelRefIterator;

            // Cap the worker count at the quant type's CPU thread limit, if any.
            let mut minimum_max_threads = {
                let current_rayon_threads = rayon::current_num_threads();
                if let Some(dtype) = dtype {
                    dtype
                        .get_max_isq_cpu_threads()
                        .map(usize::from)
                        .unwrap_or(current_rayon_threads)
                } else {
                    current_rayon_threads
                }
            };
            if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
                minimum_max_threads = 1;
            }

            if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
                // Quantize on a single thread when using collected imatrix data.
                minimum_max_threads = 1;
            }

            info!("Applying ISQ on {minimum_max_threads} threads.");

            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(minimum_max_threads)
                .build()
                .map_err(candle_core::Error::msg)?;

            let guard = QuantizeOntoGuard::new();

            pool.install(|| {
                use indicatif::ParallelProgressIterator;
                use rayon::iter::{
                    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
                };
                if silent {
                    tensors
                        .par_iter_mut()
                        .zip(devices_and_dtypes)
                        .zip(imatrix_to_weight)
                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                            **tensor = tensor
                                .clone()
                                .apply_isq(
                                    dtype,
                                    device.clone(),
                                    &n_quantized,
                                    imatrix_weight,
                                    guard.clone(),
                                )
                                .unwrap();
                            device.synchronize().unwrap();
                        });
                } else {
                    tensors
                        .par_iter_mut()
                        .zip(devices_and_dtypes)
                        .zip(imatrix_to_weight)
                        .progress_with(bar)
                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                            **tensor = tensor
                                .clone()
                                .apply_isq(
                                    dtype,
                                    device.clone(),
                                    &n_quantized,
                                    imatrix_weight,
                                    guard.clone(),
                                )
                                .unwrap();
                            device.synchronize().unwrap();
                        });
                }
            });

            if let Some(serialized) = write_artifacts {
                info!(
                    "Serializing {total_tensors} ISQ tensors to `{}`.",
                    serialized.display()
                );

                if serialized.extension().is_none_or(|ext| ext != "uqff") {
                    candle_core::bail!("UQFF output path extension must be `.uqff`");
                }

                let bar = ProgressBar::new(total_tensors as u64);
                bar.set_style(
                    ProgressStyle::default_bar()
                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                        .unwrap()
                        .progress_chars("#>-"),
                );

                #[cfg(not(feature = "metal"))]
                let n_threads = 2;
                #[cfg(feature = "metal")]
                let n_threads = 1;

                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(n_threads)
                    .build()
                    .map_err(candle_core::Error::msg)?;

                let quantized_values = pool.install(|| {
                    if silent {
                        tensors
                            .par_iter()
                            .enumerate()
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    Tensor::new(Cow::into_owned(layer.serialize()?), &Device::Cpu)?,
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    } else {
                        tensors
                            .par_iter()
                            .enumerate()
                            .progress_with(bar)
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    Tensor::new(Cow::into_owned(layer.serialize()?), &Device::Cpu)?,
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    }
                });
                let quantized_values = quantized_values?;

                let parent = serialized
                    .parent()
                    .context("Target UQFF path must have a filename!")?;

                std::fs::create_dir_all(parent)?;

                let file_stem = serialized
                    .file_stem()
                    .context("Target UQFF path must have a file stem!")?
                    .to_string_lossy()
                    .to_string();

                // Estimate the serialized size to decide how many shards to
                // write, targeting at most MAX_UQFF_SIZE_BYTES per shard.
                let size_estimate_bytes = quantized_values
                    .iter()
                    .map(|(_, x)| x.elem_count() * x.dtype().size_in_bytes())
                    .sum::<usize>();
                let n_files = size_estimate_bytes.div_ceil(MAX_UQFF_SIZE_BYTES);

                if n_files == 1 {
                    info!("Writing to `{}`", serialized.display());
                    safetensors::serialize_to_file(quantized_values, &None, serialized)?;
                } else {
                    // Round up so the chunk count never exceeds `n_files` and the
                    // chunk size can never be zero (which would panic in `chunks`).
                    let chunksize = quantized_values.len().div_ceil(n_files);
                    let quantized_values_chunks = quantized_values.into_iter().chunks(chunksize);
                    for (i, chunk) in quantized_values_chunks.into_iter().enumerate() {
                        let mut name = parent.to_path_buf();
                        name.push(format!("{file_stem}-{i}.uqff"));
                        info!("Writing shard {i} to `{}`", name.display());
                        safetensors::serialize_to_file(chunk, &None, &name)?;
                    }
                }

                let residual = match organization {
                    IsqOrganization::Default => self.residual_tensors(),
                    IsqOrganization::MoeExpertsOnly => self
                        .residual_tensors_moe_experts_only()
                        .unwrap_or(self.residual_tensors()),
                };

                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
                let config_out = parent.join("config.json");
                let tokenizer_out = parent.join("tokenizer.json");
                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
                let gen_cfg_out = parent.join("generation_config.json");
                let processor_out = parent.join("processor_config.json");
                let preprocessor_out = parent.join("preprocessor_config.json");

                info!(
                    "Serializing {} residual tensors to `{}`.",
                    residual.len(),
                    residual_out.display()
                );

                safetensors::serialize_to_file(residual, &None, &residual_out)?;

                let UqffFullSer {
                    tokenizer,
                    template_filename,
                    generation_config,
                    config,
                    processor_filename,
                    preprocessor_filename,
                } = full_ser;

                info!("Serializing configuration to `{}`.", config_out.display());

                std::fs::write(config_out, config)?;

                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());

                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
                    .map_err(candle_core::Error::msg)?;

                if let Some(template_filename) = template_filename {
                    info!(
                        "Serializing tokenizer config to `{}`.",
                        tokenizer_cfg_out.display()
                    );

                    let template =
                        std::fs::read(template_filename).map_err(candle_core::Error::msg)?;
                    std::fs::write(&tokenizer_cfg_out, template)
                        .map_err(candle_core::Error::msg)?;
                }

                if let Some(generation_config) = generation_config {
                    info!(
                        "Serializing generation config to `{}`.",
                        gen_cfg_out.display()
                    );

                    let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(processor_config) = processor_filename {
                    info!(
                        "Serializing processor config to `{}`.",
                        processor_out.display()
                    );

                    let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(preprocessor_config) = preprocessor_filename {
                    info!(
                        "Serializing preprocessor config to `{}`.",
                        preprocessor_out.display()
                    );

                    let cfg =
                        std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
                }
            }
            let delta = Instant::now().duration_since(t_start).as_secs_f32();
            info!("Applied in-situ quantization into {dtype:?} to {n_quantized:?} tensors out of {total_tensors} total tensors. Took {delta:.2}s");
        }
        Ok(())
    }

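    /// Load previously serialized ISQ (UQFF) artifacts back into the model's
    /// layers. `artifacts` lists the shard files produced by
    /// [`Self::quantize`] with `write_artifacts`. A call sketch (the paths
    /// are illustrative):
    ///
    /// ```ignore
    /// model.load_from_artifacts(
    ///     Device::Cpu,
    ///     None,  // topology
    ///     false, // silent
    ///     &[PathBuf::from("model-0.uqff"), PathBuf::from("model-1.uqff")],
    /// )?;
    /// ```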
    fn load_from_artifacts(
        &mut self,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        artifacts: &[PathBuf],
    ) -> candle_core::Result<()> {
        let (tensors, mapper) = self.get_layers();
        let total_tensors = tensors.len();

        let layers = topology.map(|x| {
            x.0.iter()
                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                .collect::<Vec<_>>()
        });

        // Resolve the target device and distributed comm for each tensor.
        let mut devices = Vec::new();
        let mut comms = Vec::new();
        for (_, layer_num) in &tensors {
            let device = if let Some(ref layers) = layers {
                if let Some(layer) = layer_num {
                    layers
                        .get(*layer)
                        .as_ref()
                        .map(|x| x.1.clone())
                        .unwrap_or(Some(device.clone()))
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                }
            } else if let Some(layer_num) = layer_num {
                mapper
                    .device_for(*layer_num, false)
                    .cloned()
                    .unwrap_or(device.clone())
            } else {
                device.clone()
            };
            devices.push(device);
            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
        }

        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::multi(artifacts)? };

        // UQFF tensor names are the indices of the ISQ layers they belong to.
        let artifact_isqs = artifacts
            .tensors()
            .into_iter()
            .map(|(name, tensor)| {
                (
                    name.parse::<usize>()
                        .expect("Name should be parseable as usize"),
                    tensor,
                )
            })
            .collect::<HashMap<_, _>>();

        if artifact_isqs.len() != total_tensors {
            candle_core::bail!(
                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
                artifact_isqs.len(),
            );
        }

        let bar = ProgressBar::new(total_tensors as u64);
        bar.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                .unwrap()
                .progress_chars("#>-"),
        );

        let t_start = Instant::now();

        let guard = QuantizeOntoGuard::new();

        if silent {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                // The serialized quant scheme is read out of the artifact itself.
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        } else {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .progress_with(bar)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        }

        let delta = Instant::now().duration_since(t_start).as_secs_f32();
        info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s");

        Ok(())
    }
}

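/// Provides the layer-name regexes that decide which weights ISQ may
/// quantize. A sketch of a typical implementation (the type and patterns are
/// illustrative, not taken from any particular model):
///
/// ```ignore
/// impl IsqModelLoader for MyLoader {
///     fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
///         Ok(vec![
///             Regex::new(r"lm_head\.(weight|bias)$")?,
///             Regex::new(r"layers\.(\d+)\.mlp\.(up|down|gate)_proj\.(weight|bias)$")?,
///         ])
///     }
/// }
/// ```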
pub(crate) trait IsqModelLoader {
    /// Regexes matching the names of layers eligible for ISQ. An empty list
    /// (the default) means no layers are quantized.
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// Regexes for the MoQE organization (MoE expert layers only); defaults
    /// to the same set as [`Self::isq_layer_regexes`].
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}