use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    env,
    fs::File,
    path::PathBuf,
    str::FromStr,
    sync::{atomic::AtomicUsize, Arc},
    time::Instant,
};
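/// A zero-copy [`safetensors::tensor::View`] over raw `u8` bytes, used to write
/// already-serialized UQFF tensor data without copying it.
///
/// Illustrative sketch (hypothetical values, not from this crate's tests):
///
/// ```ignore
/// let bytes = vec![0u8; 16];
/// let view = CowBytesView::new(Cow::Borrowed(bytes.as_slice()));
/// // `view` reports dtype U8 and shape [16], so safetensors can serialize it as-is.
/// ```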
#[derive(Clone)]
pub struct CowBytesView<'a> {
    data: Cow<'a, [u8]>,
    shape: [usize; 1],
}

impl<'a> CowBytesView<'a> {
    pub fn new(data: Cow<'a, [u8]>) -> Self {
        let len = data.len();
        Self { data, shape: [len] }
    }
}

impl safetensors::tensor::View for CowBytesView<'_> {
    fn dtype(&self) -> safetensors::tensor::Dtype {
        safetensors::tensor::Dtype::U8
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn data(&self) -> Cow<'_, [u8]> {
        // The data is always borrowed here, so this clone copies only the
        // reference, not the underlying bytes.
        assert!(matches!(self.data, Cow::Borrowed(_)));
        self.data.clone()
    }

    fn data_len(&self) -> usize {
        self.data.len()
    }
}

use anyhow::Result;
use candle_core::{quantized, Context, Device, Tensor};
use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
use itertools::Itertools;
use mistralrs_quant::{
    AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
    HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    ReplicatedLayer, RowParallelLayer, UnquantLinear,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use regex::Regex;
use serde::Deserialize;
use tokenizers::Tokenizer;
use tracing::{info, warn};

use crate::{device_map::DeviceMapper, topology::LayerTopology, Topology};

pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
/// Maximum size of a single UQFF shard: 10 GiB.
const MAX_UQFF_SIZE_BYTES: usize = 10 * 1024 * 1024 * 1024;
/// Delimiter separating multiple UQFF file paths in a single argument.
pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";

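/// Parse an [`IsqType`] from a string, case-insensitively.
///
/// Bare bit-widths (`"2"`, `"3"`, `"4"`, `"5"`, `"6"`, `"8"`) select AFQ types on Metal
/// devices and K-quant/`Q8_0` types elsewhere; explicit names (`"q4k"`, `"hqq4"`,
/// `"fp8"`, `"afq4"`, ...) select that type directly. With the `cuda` feature, types
/// unsupported on CUDA are rejected.
///
/// Illustrative sketch (no device given, so the non-Metal mapping applies):
///
/// ```ignore
/// let isq = parse_isq_value("4", None).unwrap();
/// assert!(matches!(isq, IsqType::Q4K));
/// ```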
pub fn parse_isq_value(s: &str, device: Option<&Device>) -> Result<IsqType, String> {
    let is_metal = device.map(|device| device.is_metal()).unwrap_or(false);
    let tp = match s.to_lowercase().as_str() {
        "2" if is_metal => IsqType::AFQ2,
        "2" if !is_metal => IsqType::Q2K,
        "3" if is_metal => IsqType::AFQ3,
        "3" if !is_metal => IsqType::Q3K,
        "4" if is_metal => IsqType::AFQ4,
        "4" if !is_metal => IsqType::Q4K,
        "5" => IsqType::Q5K,
        "6" if is_metal => IsqType::AFQ6,
        "6" if !is_metal => IsqType::Q6K,
        "8" if is_metal => IsqType::AFQ8,
        "8" if !is_metal => IsqType::Q8_0,
        "q4_0" => IsqType::Q4_0,
        "q4_1" => IsqType::Q4_1,
        "q5_0" => IsqType::Q5_0,
        "q5_1" => IsqType::Q5_1,
        "q8_0" => IsqType::Q8_0,
        "q8_1" => IsqType::Q8_1,
        "q2k" => IsqType::Q2K,
        "q3k" => IsqType::Q3K,
        "q4k" => IsqType::Q4K,
        "q5k" => IsqType::Q5K,
        "q6k" => IsqType::Q6K,
        "q8k" => IsqType::Q8K,
        "hqq8" => IsqType::HQQ8,
        "hqq4" => IsqType::HQQ4,
        "fp8" => IsqType::F8E4M3,
        "afq8" => IsqType::AFQ8,
        "afq6" => IsqType::AFQ6,
        "afq4" => IsqType::AFQ4,
        "afq3" => IsqType::AFQ3,
        "afq2" => IsqType::AFQ2,
        _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `5`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
    };
    #[cfg(feature = "cuda")]
    {
        if !matches!(
            tp,
            IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q8_0
                | IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q5K
                | IsqType::Q6K
                | IsqType::HQQ8
                | IsqType::HQQ4
                | IsqType::F8E4M3
        ) {
            return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
        }
    }
    Ok(tp)
}

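/// Which layers ISQ should apply to: all supported layers (`"default"`), or only
/// the MoE expert layers (`"moqe"`).
///
/// Illustrative sketch of parsing (via the `FromStr` impl below):
///
/// ```ignore
/// let org: IsqOrganization = "moqe".parse().unwrap();
/// assert!(matches!(org, IsqOrganization::MoeExpertsOnly));
/// ```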
#[derive(Clone, Debug, Copy, Default, Deserialize)]
pub enum IsqOrganization {
    #[default]
    #[serde(rename = "default")]
    Default,
    /// Only quantize MoE experts, if applicable.
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}

impl FromStr for IsqOrganization {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "default" => Ok(Self::Default),
            "moqe" => Ok(Self::MoeExpertsOnly),
            other => Err(format!(
                "Expected ISQ organization `default` or `moqe`, got `{other}`"
            )),
        }
    }
}

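/// The auxiliary files written alongside a UQFF model: tokenizer, chat template,
/// and the generation/processor/preprocessor configs, if present.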
pub struct UqffFullSer<'a> {
    pub tokenizer: &'a Tokenizer,
    pub template_filename: &'a Option<PathBuf>,
    pub generation_config: Option<&'a PathBuf>,
    pub config: String,
    pub processor_filename: &'a Option<PathBuf>,
    pub preprocessor_filename: &'a Option<PathBuf>,
}

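/// Source of imatrix (importance matrix) data for quantization: either a
/// `.imatrix`/`.cimatrix` file, or statistics collected in-process via
/// `begin_track_stats`/`extract_imatrix_data`.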
#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    File(&'a PathBuf),
    Collected,
}

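/// A model whose linear layers can be quantized in-situ (ISQ), optionally guided
/// by imatrix data, and serialized to or restored from UQFF artifacts.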
pub trait IsqModel {
    /// The ISQ-eligible matmul layers, each paired with its layer index if known,
    /// along with the model's device mapper.
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );

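    /// Begin tracking activation statistics on every ISQ-eligible layer, for later
    /// extraction as collected imatrix data.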
    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

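    /// Finish tracking and collect the per-layer imatrix statistics started by
    /// `begin_track_stats`.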
    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

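    /// Like `get_layers`, but restricted to MoE expert layers. Defaults to all layers.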
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }

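    /// Like `begin_track_stats`, but for the MoE-experts-only organization.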
    fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

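    /// Like `extract_imatrix_data`, but for the MoE-experts-only organization.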
    fn extract_imatrix_data_moe_experts_only(
        &mut self,
    ) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

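    /// Names of the imatrix entries corresponding to each ISQ layer, in order.
    /// Models must override this to support GGUF-format imatrix files.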
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        candle_core::bail!("This model does not support quantizing with an imatrix.");
    }

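    /// Residual tensors (everything that is not ISQ-quantized) for UQFF serialization.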
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;

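    /// Residual tensors for the MoE-experts-only organization; `None` falls back to
    /// `residual_tensors`.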
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }

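    /// Quantize the model in-situ. In order, this:
    /// 1. resolves imatrix data (from a file or from collected statistics),
    /// 2. picks each tensor's device and dtype from the topology or device mapper,
    /// 3. quantizes all tensors on a bounded rayon thread pool, and
    /// 4. if `write_artifacts` is set, serializes the quantized tensors (sharded at
    ///    `MAX_UQFF_SIZE_BYTES`), the residual tensors, and the configs as UQFF.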
    #[allow(clippy::too_many_arguments)]
    fn quantize(
        &mut self,
        dtype: Option<IsqType>,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        imatrix_source: Option<ImatrixDataSource<'_>>,
        organization: IsqOrganization,
        write_artifacts: Option<&PathBuf>,
        full_ser: UqffFullSer<'_>,
        multi_progress: Arc<MultiProgress>,
    ) -> candle_core::Result<()> {
        {
            let imatrix_to_weight = match imatrix_source {
                Some(ImatrixDataSource::File(imatrix)) => {
                    let ext = imatrix.extension().ok_or(candle_core::Error::msg(
                        "Expected an extension for the imatrix source file.",
                    ))?;
                    if ext == "cimatrix" {
                        info!(
                            "Loading collected imatrix source file: `{}`",
                            imatrix.display()
                        );
                        Some(CollectedImatrixData::load_imatrix(imatrix)?.0)
                    } else {
                        if ext != "imatrix" {
                            warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
                        }
                        info!(
                            "Loading GGUF-format imatrix source file: `{}`",
                            imatrix.display()
                        );
                        let mut imatrix_data =
                            quantized::imatrix_file::load_imatrix(imatrix.clone())?;
                        // Map each ISQ layer index to its imatrix weights by name.
                        let imatrix_mapping = self
                            .imatrix_names()?
                            .into_iter()
                            .enumerate()
                            .collect::<HashMap<_, _>>();

                        let layer_to_weight = imatrix_mapping
                            .into_iter()
                            .map(|(i, name)| {
                                if let Some(name) = name {
                                    (i, Some(imatrix_data.remove(&name).unwrap()))
                                } else {
                                    (i, None)
                                }
                            })
                            .collect::<HashMap<_, _>>();
                        info!(
                            "Quantizing with imatrix file `{}`, {} imatrix weights",
                            imatrix.display(),
                            layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
                        );
                        Some(layer_to_weight)
                    }
                }
                Some(ImatrixDataSource::Collected) => {
                    let data = match organization {
                        IsqOrganization::Default => self.extract_imatrix_data()?,
                        IsqOrganization::MoeExpertsOnly => {
                            self.extract_imatrix_data_moe_experts_only()?
                        }
                    };
                    let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
                    let save_path = format!("collected-{count}.cimatrix");
                    info!("Saving collected imatrix data to `{save_path}`");
                    data.save_imatrix(save_path)?;
                    info!("Quantizing with collected imatrix data, {count} imatrix weights");
                    Some(data.0)
                }
                None => None,
            };

            let (mut tensors, mapper) = match organization {
                IsqOrganization::Default => self.get_layers(),
                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
            };

            // Flatten the imatrix map into per-tensor entries, ordered by layer index.
            let imatrix_to_weight: Vec<Option<Vec<f32>>> =
                if let Some(mut imatrix_to_weight) = imatrix_to_weight {
                    let ordered_keys = imatrix_to_weight
                        .keys()
                        .copied()
                        .sorted()
                        .collect::<Vec<_>>();
                    ordered_keys
                        .into_iter()
                        .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
                        .collect()
                } else {
                    vec![None; tensors.len()]
                };

            let total_tensors = tensors.len();
            let n_quantized = AtomicUsize::new(0);
            if let Some(topology) = topology {
                let mut dtypes = HashSet::new();
                for layer in topology.0.iter().flatten() {
                    if let LayerTopology {
                        isq: Some(isq_dtype),
                        device: _,
                    } = layer
                    {
                        dtypes.insert(isq_dtype);
                    }
                }
                info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
            } else {
                info!("Applying in-situ quantization into {dtype:?} to {total_tensors} tensors.");
            }
            let bar = ProgressBar::new(total_tensors as u64);
            bar.set_style(
                ProgressStyle::default_bar()
                    .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
                    .unwrap()
                    .progress_chars("#>-"),
            );
            multi_progress.add(bar.clone());

            let layers = topology.map(|x| {
                x.0.iter()
                    .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                    .collect::<Vec<_>>()
            });

            // Per-tensor device and dtype: the topology takes precedence, then the
            // device mapper, then the default device and requested dtype.
            let mut devices_and_dtypes = Vec::new();
            for (_, layer_num) in &tensors {
                let device = if let Some(ref layers) = layers {
                    if let Some(layer) = layer_num {
                        layers
                            .get(*layer)
                            .as_ref()
                            .map(|x| x.1.clone())
                            .unwrap_or(Some(device.clone()))
                            .unwrap_or(device.clone())
                    } else {
                        device.clone()
                    }
                } else if let Some(layer_num) = layer_num {
                    mapper
                        .device_for(*layer_num, false)
                        .cloned()
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                };
                let dtype = if let Some(ref layers) = layers {
                    if let Some(layer) = layer_num {
                        layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
                    } else {
                        dtype
                    }
                } else {
                    dtype
                };
                devices_and_dtypes.push((device, dtype));
            }

            let t_start = Instant::now();

            use rayon::iter::IntoParallelRefIterator;

            // Cap the thread count at the most the quant method allows on CPU,
            // falling back to rayon's current thread count.
            let mut minimum_max_threads = {
                let current_rayon_threads = rayon::current_num_threads();
                if let Some(dtype) = dtype {
                    dtype
                        .get_max_isq_cpu_threads()
                        .map(usize::from)
                        .unwrap_or(current_rayon_threads)
                } else {
                    current_rayon_threads
                }
            };
            if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
                minimum_max_threads = 1;
            }

            if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
                // Quantize on a single thread when using collected imatrix data.
                minimum_max_threads = 1;
            }

            info!("Applying ISQ on {minimum_max_threads} threads.");

            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(minimum_max_threads)
                .build()
                .map_err(candle_core::Error::msg)?;

            let guard = QuantizeOntoGuard::new();

            pool.install(|| {
                use indicatif::ParallelProgressIterator;
                use rayon::iter::{
                    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
                };
                if silent {
                    tensors
                        .par_iter_mut()
                        .zip(devices_and_dtypes)
                        .zip(imatrix_to_weight)
                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                            **tensor = tensor
                                .clone()
                                .apply_isq(
                                    dtype,
                                    device.clone(),
                                    &n_quantized,
                                    imatrix_weight,
                                    guard.clone(),
                                )
                                .unwrap();
                            device.synchronize().unwrap();
                        });
                } else {
                    tensors
                        .par_iter_mut()
                        .zip(devices_and_dtypes)
                        .zip(imatrix_to_weight)
                        .progress_with(bar)
                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                            **tensor = tensor
                                .clone()
                                .apply_isq(
                                    dtype,
                                    device.clone(),
                                    &n_quantized,
                                    imatrix_weight,
                                    guard.clone(),
                                )
                                .unwrap();
                            device.synchronize().unwrap();
                        });
                }
            });

            if let Some(serialized) = write_artifacts {
                info!(
                    "Serializing {total_tensors} ISQ tensors to `{}`.",
                    serialized.display()
                );

                if serialized.extension().is_none_or(|ext| ext != "uqff") {
                    candle_core::bail!("UQFF output path extension must be `.uqff`");
                }

                let bar = ProgressBar::new(total_tensors as u64);
                bar.set_style(
                    ProgressStyle::default_bar()
                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                        .unwrap()
                        .progress_chars("#>-"),
                );

                #[cfg(not(feature = "metal"))]
                let n_threads = 2;
                #[cfg(feature = "metal")]
                let n_threads = 1;

                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(n_threads)
                    .build()
                    .map_err(candle_core::Error::msg)?;

                let quantized_values = pool.install(|| {
                    if silent {
                        tensors
                            .par_iter()
                            .enumerate()
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    match layer.serialize()? {
                                        Cow::Borrowed(_) => unreachable!(),
                                        Cow::Owned(owned) => owned,
                                    },
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    } else {
                        tensors
                            .par_iter()
                            .enumerate()
                            .progress_with(bar)
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    match layer.serialize()? {
                                        Cow::Borrowed(_) => unreachable!(),
                                        Cow::Owned(owned) => owned,
                                    },
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    }
                });
                let quantized_values = quantized_values?;

                let parent = serialized
                    .parent()
                    .context("Target UQFF path must have a filename!")?;

                std::fs::create_dir_all(parent)?;

                let file_stem = serialized
                    .file_stem()
                    .context("Target UQFF path must have a file stem!")?
                    .to_string_lossy()
                    .to_string();

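                // Greedily pack the serialized tensors into shards, starting a new
                // shard whenever adding the next tensor would exceed MAX_UQFF_SIZE_BYTES.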
                let mut current_chunk = Vec::new();
                let mut current_bytes: usize = 0;
                let mut shard_index = 0;

                for (name, tensor) in quantized_values.iter() {
                    let tensor_bytes = tensor.len();
                    if !current_chunk.is_empty()
                        && current_bytes + tensor_bytes > MAX_UQFF_SIZE_BYTES
                    {
                        let mut shard_path = parent.to_path_buf();
                        shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
                        info!(
                            "Writing shard {} to `{}`",
                            shard_index,
                            shard_path.display()
                        );
                        safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
                        shard_index += 1;
                        current_chunk.clear();
                        current_bytes = 0;
                    }
                    current_bytes += tensor_bytes;
                    current_chunk.push((name, CowBytesView::new(Cow::Borrowed(tensor))));
                }

                if !current_chunk.is_empty() {
                    let mut shard_path = parent.to_path_buf();
                    shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
                    info!(
                        "Writing final shard {} to `{}`",
                        shard_index,
                        shard_path.display()
                    );
                    safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
                }

                let residual = match organization {
                    IsqOrganization::Default => self.residual_tensors(),
                    IsqOrganization::MoeExpertsOnly => self
                        .residual_tensors_moe_experts_only()
                        .unwrap_or(self.residual_tensors()),
                };

                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
                let config_out = parent.join("config.json");
                let tokenizer_out = parent.join("tokenizer.json");
                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
                let chat_template_jinja_out = parent.join("chat_template.jinja");
                let gen_cfg_out = parent.join("generation_config.json");
                let processor_out = parent.join("processor_config.json");
                let preprocessor_out = parent.join("preprocessor_config.json");

                info!(
                    "Serializing {} residual tensors to `{}`.",
                    residual.len(),
                    residual_out.display()
                );

                safetensors::serialize_to_file(residual, None, &residual_out)?;

                let UqffFullSer {
                    tokenizer,
                    template_filename,
                    generation_config,
                    config,
                    processor_filename,
                    preprocessor_filename,
                } = full_ser;

                info!("Serializing configuration to `{}`.", config_out.display());

                std::fs::write(config_out, config)?;

                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());

                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
                    .map_err(candle_core::Error::msg)?;

                if let Some(template_filename) = template_filename {
                    let template =
                        std::fs::read(template_filename).map_err(candle_core::Error::msg)?;

                    if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
                        info!(
                            "Serializing chat template to `{}`.",
                            chat_template_jinja_out.display()
                        );
                        std::fs::write(&chat_template_jinja_out, template)
                            .map_err(candle_core::Error::msg)?;
                    } else {
                        info!(
                            "Serializing tokenizer config to `{}`.",
                            tokenizer_cfg_out.display()
                        );
                        std::fs::write(&tokenizer_cfg_out, template)
                            .map_err(candle_core::Error::msg)?;
                    }
                }

                if let Some(generation_config) = generation_config {
                    info!(
                        "Serializing generation config to `{}`.",
                        gen_cfg_out.display()
                    );

                    let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(processor_config) = processor_filename {
                    info!(
                        "Serializing processor config to `{}`.",
                        processor_out.display()
                    );

                    let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(preprocessor_config) = preprocessor_filename {
                    info!(
                        "Serializing preprocessor config to `{}`.",
                        preprocessor_out.display()
                    );

                    let cfg =
                        std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
                }
            }
            let delta = Instant::now().duration_since(t_start).as_secs_f32();
            info!("Applied in-situ quantization into {dtype:?} to {n_quantized:?} tensors out of {total_tensors} total tensors. Took {delta:.2}s");
        }
        Ok(())
    }

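    /// Load previously serialized UQFF artifacts back into the model's layers,
    /// dispatching on each layer's distributed kind, or else on the quantization
    /// type stored in the artifact itself.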
    fn load_from_artifacts(
        &mut self,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        artifacts: &[PathBuf],
    ) -> candle_core::Result<()> {
        let (tensors, mapper) = self.get_layers();
        let total_tensors = tensors.len();

        let layers = topology.map(|x| {
            x.0.iter()
                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                .collect::<Vec<_>>()
        });

        let mut devices = Vec::new();
        let mut comms = Vec::new();
        for (_, layer_num) in &tensors {
            let device = if let Some(ref layers) = layers {
                if let Some(layer) = layer_num {
                    layers
                        .get(*layer)
                        .as_ref()
                        .map(|x| x.1.clone())
                        .unwrap_or(Some(device.clone()))
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                }
            } else if let Some(layer_num) = layer_num {
                mapper
                    .device_for(*layer_num, false)
                    .cloned()
                    .unwrap_or(device.clone())
            } else {
                device.clone()
            };
            devices.push(device);
            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?);
        }

        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::multi(artifacts)? };

        let artifact_isqs = artifacts
            .tensors()
            .into_iter()
            .map(|(name, tensor)| {
                (
                    name.parse::<usize>()
                        .expect("Name should be parseable as usize"),
                    tensor,
                )
            })
            .collect::<HashMap<_, _>>();

        if artifact_isqs.len() != total_tensors {
            candle_core::bail!(
                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
                artifact_isqs.len(),
            );
        }

        let bar = ProgressBar::new(total_tensors as u64);
        bar.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                .unwrap()
                .progress_chars("#>-"),
        );

        let t_start = Instant::now();

        let guard = QuantizeOntoGuard::new();

        if silent {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                // The quantization type is serialized as a byte at a
                                // fixed offset in the UQFF artifact.
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        } else {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .progress_with(bar)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                // The quantization type is serialized as a byte at a
                                // fixed offset in the UQFF artifact.
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        }

        let delta = Instant::now().duration_since(t_start).as_secs_f32();
        info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s");

        Ok(())
    }
}

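/// Predicates selecting which layers of a model are eligible for ISQ, expressed as
/// regexes over tensor names.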
pub(crate) trait IsqModelLoader {
    /// Regexes matching layers to quantize immediately during loading, in the
    /// default organization. Defaults to no layers.
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// Regexes matching layers to quantize immediately during loading, in the
    /// MoE-experts-only ("moqe") organization.
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }

    /// Regexes matching layers eligible for ISQ. Defaults to no layers.
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// Regexes matching layers eligible for ISQ in the MoE-experts-only ("moqe")
    /// organization.
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
1060}