1use std::{
2 borrow::Cow,
3 collections::{HashMap, HashSet},
4 env,
5 fs::File,
6 path::PathBuf,
7 str::FromStr,
8 sync::{atomic::AtomicUsize, Arc},
9 time::Instant,
10};
#[derive(Clone)]
/// Zero-copy byte view over an already-serialized tensor blob.
///
/// Implements the `safetensors` `View` trait so raw UQFF bytes can be written
/// into a safetensors file as a rank-1 `U8` tensor without copying them.
pub struct CowBytesView<'a> {
    // The raw bytes; expected to be `Cow::Borrowed` (see the assert in `data`).
    data: Cow<'a, [u8]>,
    // Rank-1 shape: a single dimension equal to the byte length.
    shape: [usize; 1],
}
23
24impl<'a> CowBytesView<'a> {
25 pub fn new(data: Cow<'a, [u8]>) -> Self {
27 let len = data.len();
28 Self { data, shape: [len] }
29 }
30}
31
32impl safetensors::tensor::View for CowBytesView<'_> {
33 fn dtype(&self) -> safetensors::tensor::Dtype {
34 safetensors::tensor::Dtype::U8
36 }
37
38 fn shape(&self) -> &[usize] {
39 &self.shape
40 }
41
42 fn data(&self) -> Cow<'_, [u8]> {
43 assert!(matches!(self.data, Cow::Borrowed(_)));
44 self.data.clone()
46 }
47
48 fn data_len(&self) -> usize {
49 self.data.len()
50 }
51}
52
53use anyhow::Result;
54use candle_core::{quantized, Context, Device, Tensor};
55use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
56use itertools::Itertools;
57use mistralrs_quant::{
58 AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
59 HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
60 ReplicatedLayer, RowParallelLayer, UnquantLinear,
61};
62use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
63use regex::Regex;
64use serde::Deserialize;
65use tokenizers::Tokenizer;
66use tracing::{info, warn};
67
68use crate::{
69 device_map::DeviceMapper, pipeline::EmbeddingModulePaths, topology::LayerTopology,
70 utils::progress::configure_progress_bar, Topology,
71};
72
/// Filename for the non-quantized (residual) tensors written next to UQFF shards.
pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
// Maximum size of a single UQFF shard file (10 GiB); larger outputs are split
// into `{stem}-{index}.uqff` shards.
const MAX_UQFF_SIZE_BYTES: usize = 10 * 1024 * 1024 * 1024;
/// Delimiter separating multiple UQFF file paths supplied in a single string.
pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";
77
78pub fn parse_isq_value(s: &str, device: Option<&Device>) -> Result<IsqType, String> {
107 let is_metal = device.map(|device| device.is_metal()).unwrap_or(false);
108 let tp = match s.to_lowercase().as_str() {
109 "2" if is_metal => IsqType::AFQ2,
110 "2" if !is_metal => IsqType::Q2K,
111 "3" if is_metal => IsqType::AFQ3,
112 "3" if !is_metal => IsqType::Q3K,
113 "4" if is_metal => IsqType::AFQ4,
114 "4" if !is_metal => IsqType::Q4K,
115 "5" => IsqType::Q5K,
116 "6" if is_metal => IsqType::AFQ6,
117 "6" if !is_metal => IsqType::Q6K,
118 "8" if is_metal => IsqType::AFQ8,
119 "8" if !is_metal => IsqType::Q8_0,
120 "q4_0" => IsqType::Q4_0,
121 "q4_1" => IsqType::Q4_1,
122 "q5_0" => IsqType::Q5_0,
123 "q5_1" => IsqType::Q5_1,
124 "q8_0" => IsqType::Q8_0,
125 "q8_1" => IsqType::Q8_1,
126 "q2k" => IsqType::Q2K,
127 "q3k" => IsqType::Q3K,
128 "q4k" => IsqType::Q4K,
129 "q5k" => IsqType::Q5K,
130 "q6k" => IsqType::Q6K,
131 "q8k" => IsqType::Q8K,
132 "hqq8" => IsqType::HQQ8,
133 "hqq4" => IsqType::HQQ4,
134 "fp8" => IsqType::F8E4M3,
135 "afq8" => IsqType::AFQ8,
136 "afq6" => IsqType::AFQ6,
137 "afq4" => IsqType::AFQ4,
138 "afq3" => IsqType::AFQ3,
139 "afq2" => IsqType::AFQ2,
140 _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
144 };
145 #[cfg(feature = "cuda")]
146 {
147 if !matches!(
148 tp,
149 IsqType::Q4_0
150 | IsqType::Q4_1
151 | IsqType::Q5_0
152 | IsqType::Q5_1
153 | IsqType::Q8_0
154 | IsqType::Q2K
155 | IsqType::Q3K
156 | IsqType::Q4K
157 | IsqType::Q5K
158 | IsqType::Q6K
159 | IsqType::HQQ8
160 | IsqType::HQQ4
161 | IsqType::F8E4M3 ) {
165 return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
166 }
167 }
168 Ok(tp)
169}
170
/// Selects which layers in-situ quantization applies to.
#[derive(Clone, Debug, Copy, Default, Deserialize)]
pub enum IsqOrganization {
    /// Quantize all eligible layers (serde name: "default").
    #[default]
    #[serde(rename = "default")]
    Default,
    /// Quantize only MoE expert layers (serde name: "moqe").
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}
181
182impl FromStr for IsqOrganization {
183 type Err = String;
184 fn from_str(s: &str) -> Result<Self, Self::Err> {
185 match s {
186 "default" => Ok(Self::Default),
187 "moqe" => Ok(Self::MoeExpertsOnly),
188 other => Err(format!(
189 "Expected ISQ organization `default` or `moqe`, got `{other}`"
190 )),
191 }
192 }
193}
194
/// Everything needed to fully serialize a model alongside UQFF artifacts:
/// tokenizer, configs, and optional template/processor/module files that are
/// copied next to the `.uqff` output.
pub struct UqffFullSer<'a> {
    pub tokenizer: &'a Tokenizer,
    // Tokenizer config or `.jinja` chat template; which one is decided by the
    // file extension at serialization time.
    pub template_filename: &'a Option<PathBuf>,
    // JSON modules manifest contents, if the model has embedding modules.
    pub modules: Option<&'a String>,
    pub module_paths: Option<&'a [EmbeddingModulePaths]>,
    pub generation_config: Option<&'a PathBuf>,
    // The model's config.json contents.
    pub config: String,
    pub processor_filename: &'a Option<PathBuf>,
    pub preprocessor_filename: &'a Option<PathBuf>,
}
205
/// Where imatrix (importance matrix) data for quantization comes from.
#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    /// Load from a file: `.cimatrix` (collected format) or GGUF `.imatrix`.
    File(&'a PathBuf),
    /// Use statistics collected at runtime via `begin_track_stats`.
    Collected,
}
211
212pub trait IsqModel {
    /// All ISQ-eligible layers: each entry pairs a mutable handle to a
    /// quantizable layer with its optional layer index (used for device
    /// mapping), plus the model's device mapper.
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );
221
222 fn begin_track_stats(&mut self) -> anyhow::Result<()> {
224 let layers = self
225 .get_layers()
226 .0
227 .into_iter()
228 .map(|(layer, _)| layer)
229 .collect::<Vec<_>>();
230 for layer in layers {
231 Arc::get_mut(layer).unwrap().begin_track_stats()?;
232 }
233 Ok(())
234 }
235
236 fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
238 let layers = self
239 .get_layers()
240 .0
241 .into_iter()
242 .enumerate()
243 .map(|(i, (layer, _))| (i, layer))
244 .collect::<Vec<_>>();
245 let mut data = HashMap::new();
246 for (i, layer) in layers {
247 data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
248 }
249 Ok(CollectedImatrixData(data))
250 }
251
    /// Like `get_layers`, but restricted to MoE expert layers. The default
    /// implementation falls back to all layers; MoE models should override it.
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }
263
264 fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
267 let layers = self
268 .get_layers()
269 .0
270 .into_iter()
271 .map(|(layer, _)| layer)
272 .collect::<Vec<_>>();
273 for layer in layers {
274 Arc::get_mut(layer).unwrap().begin_track_stats()?;
275 }
276 Ok(())
277 }
278
279 fn extract_imatrix_data_moe_experts_only(
282 &mut self,
283 ) -> candle_core::Result<CollectedImatrixData> {
284 let layers = self
285 .get_layers()
286 .0
287 .into_iter()
288 .enumerate()
289 .map(|(i, (layer, _))| (i, layer))
290 .collect::<Vec<_>>();
291 let mut data = HashMap::new();
292 for (i, layer) in layers {
293 data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
294 }
295 Ok(CollectedImatrixData(data))
296 }
297
    /// Imatrix entry name for each layer, in `get_layers` order (`None` for
    /// layers without imatrix data). Models must override this to support
    /// GGUF-format imatrix files; the default refuses.
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        candle_core::bail!("This model does not support quantizing with an imatrix.");
    }
308
309 fn residual_tensors(&self) -> Vec<(String, Tensor)>;
311
    /// Residual tensors for the MoQE (experts-only) organization; `None` (the
    /// default) makes `quantize` fall back to `residual_tensors`.
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }
316
    /// Quantize the model's eligible layers in-situ (ISQ), optionally guided by
    /// imatrix data, and optionally serialize the result as UQFF artifacts.
    ///
    /// * `dtype` — fallback ISQ type for layers without a topology override.
    /// * `device` — fallback device when neither topology nor mapper provides one.
    /// * `topology` — optional per-layer (isq, device) overrides.
    /// * `silent` — suppress progress bars.
    /// * `imatrix_source` — file or runtime-collected imatrix data; only
    ///   consumed when `apply_quantization` is true.
    /// * `organization` — quantize all layers, or MoE experts only.
    /// * `apply_quantization` — when false, skips the quantization pass and
    ///   only serializes existing tensors (if `write_artifacts` is set).
    /// * `write_artifacts` — target `.uqff` path; output is sharded at
    ///   `MAX_UQFF_SIZE_BYTES` and supporting files from `full_ser` are written
    ///   next to it.
    #[allow(clippy::too_many_arguments)]
    fn quantize(
        &mut self,
        dtype: Option<IsqType>,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        imatrix_source: Option<ImatrixDataSource<'_>>,
        organization: IsqOrganization,
        apply_quantization: bool,
        write_artifacts: Option<&PathBuf>,
        full_ser: UqffFullSer<'_>,
        multi_progress: Arc<MultiProgress>,
    ) -> candle_core::Result<()> {
        {
            let mut imatrix_source = imatrix_source;
            // Resolve the imatrix source into a per-layer-index weight map.
            // `.take()` consumes the source so the later `matches!` check on
            // `imatrix_source` only sees a still-`Some` value in the
            // not-applying-quantization path.
            let mut imatrix_to_weight_map: Option<HashMap<usize, Option<Vec<f32>>>> =
                if apply_quantization {
                    match imatrix_source.take() {
                        Some(ImatrixDataSource::File(imatrix)) => {
                            let ext = imatrix.extension().ok_or(candle_core::Error::msg(
                                "Expected an extension for the imatrix source file.",
                            ))?;
                            if ext == "cimatrix" {
                                // Our own collected format: already keyed by layer index.
                                info!(
                                    "Loading collected imatrix source file: `{}`",
                                    imatrix.display()
                                );
                                let data = CollectedImatrixData::load_imatrix(imatrix)?;
                                info!(
                                    "Quantizing with collected imatrix data, {} imatrix weights",
                                    data.0.iter().filter(|(_, x)| x.is_some()).count()
                                );
                                Some(data.0)
                            } else {
                                if ext != "imatrix" {
                                    warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
                                }
                                info!(
                                    "Loading GGUF-format imatrix source file: `{}`",
                                    imatrix.display()
                                );
                                // GGUF imatrix files are keyed by tensor name;
                                // map names -> layer indices via imatrix_names().
                                let mut imatrix_data =
                                    quantized::imatrix_file::load_imatrix(imatrix.clone())?;
                                let imatrix_mapping = self
                                    .imatrix_names()?
                                    .into_iter()
                                    .enumerate()
                                    .collect::<HashMap<_, _>>();

                                // NOTE: unwrap here panics if a named layer is
                                // missing from the imatrix file.
                                let layer_to_weight = imatrix_mapping
                                    .into_iter()
                                    .map(|(i, name)| {
                                        if let Some(name) = name {
                                            (i, Some(imatrix_data.remove(&name).unwrap()))
                                        } else {
                                            (i, None)
                                        }
                                    })
                                    .collect::<HashMap<_, _>>();
                                info!(
                                    "Quantizing with imatrix file `{}`, {} imatrix weights",
                                    imatrix.display(),
                                    layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
                                );
                                Some(layer_to_weight)
                            }
                        }
                        Some(ImatrixDataSource::Collected) => {
                            let data = match organization {
                                IsqOrganization::Default => self.extract_imatrix_data()?,
                                IsqOrganization::MoeExpertsOnly => {
                                    self.extract_imatrix_data_moe_experts_only()?
                                }
                            };
                            // Persist the collected data so a later run can
                            // reuse it as a `.cimatrix` file.
                            let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
                            let save_path = format!("collected-{count}.cimatrix");
                            info!("Saving collected imatrix data to `{save_path}`");
                            data.save_imatrix(save_path)?;
                            info!(
                                "Quantizing with collected imatrix data, {count} imatrix weights"
                            );
                            Some(data.0)
                        }
                        None => None,
                    }
                } else {
                    if imatrix_source.is_some() {
                        info!("Imatrix source provided but quantization disabled; ignoring input.");
                    }
                    None
                };

            let (mut tensors, mapper) = match organization {
                IsqOrganization::Default => self.get_layers(),
                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
            };

            let total_tensors = tensors.len();

            if apply_quantization {
                // Flatten the index-keyed map into a Vec aligned with `tensors`
                // (keys sorted ascending); all-None if no imatrix was given.
                let imatrix_to_weight: Vec<Option<Vec<f32>>> =
                    if let Some(mut imatrix_to_weight) = imatrix_to_weight_map.take() {
                        let ordered_keys = imatrix_to_weight
                            .keys()
                            .copied()
                            .sorted()
                            .collect::<Vec<_>>();
                        ordered_keys
                            .into_iter()
                            .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
                            .collect()
                    } else {
                        vec![None; tensors.len()]
                    };

                let n_quantized = AtomicUsize::new(0);
                if let Some(topology) = topology {
                    // Log the set of distinct ISQ types the topology requests.
                    let mut dtypes = HashSet::new();
                    for layer in topology.layers.iter().flatten() {
                        if let LayerTopology {
                            isq: Some(isq_dtype),
                            device: _,
                        } = layer
                        {
                            dtypes.insert(isq_dtype);
                        }
                    }
                    info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
                } else {
                    info!(
                        "Applying in-situ quantization into {dtype:?} to {total_tensors} tensors."
                    );
                }
                let bar = ProgressBar::new(total_tensors as u64);
                configure_progress_bar(&bar);
                bar.set_style(
                    ProgressStyle::default_bar()
                        .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
                        .unwrap()
                        .progress_chars("#>-"),
                );
                multi_progress.add(bar.clone());

                // Per-layer (isq, device) overrides from the topology, if any.
                let layers = topology.map(|x| {
                    x.layers
                        .iter()
                        .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                        .collect::<Vec<_>>()
                });

                // Resolve the target device and ISQ type for every tensor:
                // topology entry wins, then the device mapper, then the
                // `device`/`dtype` fallbacks passed in.
                let mut devices_and_dtypes = Vec::new();
                for (_, layer_num) in &tensors {
                    let device = if let Some(ref layers) = layers {
                        if let Some(layer) = layer_num {
                            layers
                                .get(*layer)
                                .as_ref()
                                .map(|x| x.1.clone())
                                .unwrap_or(Some(device.clone()))
                                .unwrap_or(device.clone())
                        } else {
                            device.clone()
                        }
                    } else if let Some(layer_num) = layer_num {
                        mapper
                            .device_for(*layer_num, false)
                            .cloned()
                            .unwrap_or(device.clone())
                    } else {
                        device.clone()
                    };
                    let dtype = if let Some(ref layers) = layers {
                        if let Some(layer) = layer_num {
                            layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
                        } else {
                            dtype
                        }
                    } else {
                        dtype
                    };
                    devices_and_dtypes.push((device, dtype));
                }

                let t_start = Instant::now();

                // Thread count: the ISQ type's per-type CPU cap if any, else
                // rayon's current pool size; overridable down to 1.
                let mut minimum_max_threads = {
                    let current_rayon_threads = rayon::current_num_threads();
                    if let Some(dtype) = dtype {
                        dtype
                            .get_max_isq_cpu_threads()
                            .map(usize::from)
                            .unwrap_or(current_rayon_threads)
                    } else {
                        current_rayon_threads
                    }
                };
                if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
                    minimum_max_threads = 1;
                }

                // NOTE(review): with collected imatrix data the pass is forced
                // single-threaded — presumably for deterministic/stat-safe
                // application; confirm before changing.
                if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
                    minimum_max_threads = 1;
                }

                info!("Applying ISQ on {minimum_max_threads} threads.");

                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(minimum_max_threads)
                    .build()
                    .map_err(candle_core::Error::msg)?;

                let guard = QuantizeOntoGuard::new();

                // Quantize every tensor in parallel; the two branches differ
                // only in whether the progress bar is attached.
                pool.install(|| {
                    use indicatif::ParallelProgressIterator;
                    use rayon::iter::{
                        IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
                    };
                    if silent {
                        tensors
                            .par_iter_mut()
                            .zip(devices_and_dtypes)
                            .zip(imatrix_to_weight)
                            .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                                **tensor = tensor
                                    .clone()
                                    .apply_isq(
                                        dtype,
                                        device.clone(),
                                        &n_quantized,
                                        imatrix_weight,
                                        guard.clone(),
                                    )
                                    .unwrap();
                                device.synchronize().unwrap();
                            });
                    } else {
                        tensors
                            .par_iter_mut()
                            .zip(devices_and_dtypes)
                            .zip(imatrix_to_weight)
                            .progress_with(bar)
                            .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                                **tensor = tensor
                                    .clone()
                                    .apply_isq(
                                        dtype,
                                        device.clone(),
                                        &n_quantized,
                                        imatrix_weight,
                                        guard.clone(),
                                    )
                                    .unwrap();
                                device.synchronize().unwrap();
                            });
                    }
                });

                let t_end = Instant::now();
                info!(
                    "Finished quantization pass in {:.2}s ({} tensors).",
                    t_end.duration_since(t_start).as_secs_f32(),
                    total_tensors
                );
            } else if imatrix_source.is_some() {
                info!(
                    "Imatrix data provided but quantization was skipped; existing tensors will be serialized as-is."
                );
            } else if write_artifacts.is_some() {
                info!(
                    "Skipping additional quantization; serializing {total_tensors} existing tensors."
                );
            }

            // UQFF serialization: write quantized tensors (sharded), residual
            // tensors, and the full set of supporting config files.
            if let Some(serialized) = write_artifacts {
                info!(
                    "Serializing {total_tensors} ISQ tensors to `{}`.",
                    serialized.display()
                );

                if serialized.extension().is_none_or(|ext| ext != "uqff") {
                    candle_core::bail!("UQFF output path extension must be `.uqff`",);
                }

                let bar = ProgressBar::new(total_tensors as u64);
                configure_progress_bar(&bar);
                bar.set_style(
                    ProgressStyle::default_bar()
                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                        .unwrap()
                        .progress_chars("#>-"),
                );

                // Serialization parallelism is deliberately low (1 on Metal).
                #[cfg(not(feature = "metal"))]
                let n_threads = 2;
                #[cfg(feature = "metal")]
                let n_threads = 1;

                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(n_threads)
                    .build()
                    .map_err(candle_core::Error::msg)?;

                // Serialize each supported layer to bytes, keyed by its index.
                // Layers that don't support serde are skipped (filter).
                let quantized_values = pool.install(|| {
                    use rayon::iter::IntoParallelRefIterator;
                    if silent {
                        tensors
                            .par_iter()
                            .enumerate()
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    match layer.serialize()? {
                                        Cow::Borrowed(_) => unreachable!(),
                                        Cow::Owned(owned) => owned,
                                    },
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    } else {
                        tensors
                            .par_iter()
                            .enumerate()
                            .progress_with(bar)
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    match layer.serialize()? {
                                        Cow::Borrowed(_) => unreachable!(),
                                        Cow::Owned(owned) => owned,
                                    },
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    }
                });
                let quantized_values = quantized_values?;

                let parent = serialized
                    .parent()
                    .context("Target UQFF path must have a filename!")?;

                std::fs::create_dir_all(parent)?;

                let file_stem = serialized
                    .file_stem()
                    .context("Target UQFF path must have a file stem!")?
                    .to_string_lossy()
                    .to_string();

                // Greedy sharding: flush the current chunk to
                // `{stem}-{index}.uqff` whenever adding the next tensor would
                // exceed MAX_UQFF_SIZE_BYTES.
                let mut current_chunk = Vec::new();
                let mut current_bytes: usize = 0;
                let mut shard_index = 0;

                for (name, tensor) in quantized_values.iter() {
                    let tensor_bytes = tensor.len();
                    if !current_chunk.is_empty()
                        && current_bytes + tensor_bytes > MAX_UQFF_SIZE_BYTES
                    {
                        let mut shard_path = parent.to_path_buf();
                        shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
                        info!(
                            "Writing shard {} to `{}`",
                            shard_index,
                            shard_path.display()
                        );
                        safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
                        shard_index += 1;
                        current_chunk.clear();
                        current_bytes = 0;
                    }
                    current_bytes += tensor_bytes;
                    current_chunk.push((name, CowBytesView::new(Cow::Borrowed(tensor))));
                }

                // Flush the final (possibly only) shard.
                if !current_chunk.is_empty() {
                    let mut shard_path = parent.to_path_buf();
                    shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
                    info!(
                        "Writing final shard {} to `{}`",
                        shard_index,
                        shard_path.display()
                    );
                    safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
                }

                let residual = match organization {
                    IsqOrganization::Default => self.residual_tensors(),
                    IsqOrganization::MoeExpertsOnly => self
                        .residual_tensors_moe_experts_only()
                        .unwrap_or(self.residual_tensors()),
                };

                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
                let config_out = parent.join("config.json");
                let modules_out = parent.join("modules.json");
                let tokenizer_out = parent.join("tokenizer.json");
                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
                let chat_template_jinja_out = parent.join("chat_template.jinja");
                let gen_cfg_out = parent.join("generation_config.json");
                let processor_out = parent.join("processor_config.json");
                let preprocessor_out = parent.join("preprocessor_config.json");

                info!(
                    "Serializing {} residual tensors to `{}`.",
                    residual.len(),
                    residual_out.display()
                );

                safetensors::serialize_to_file(residual, None, &residual_out)?;

                let UqffFullSer {
                    tokenizer,
                    template_filename,
                    modules,
                    module_paths,
                    generation_config,
                    config,
                    processor_filename,
                    preprocessor_filename,
                } = full_ser;

                info!("Serializing configuration to `{}`.", config_out.display());

                std::fs::write(config_out, config)?;

                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());

                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
                    .map_err(candle_core::Error::msg)?;

                // A `.jinja` template file is a chat template; anything else is
                // treated as a tokenizer config.
                if let Some(template_filename) = template_filename {
                    let template =
                        std::fs::read(template_filename).map_err(candle_core::Error::msg)?;

                    if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
                        info!(
                            "Serializing chat template to `{}`.",
                            chat_template_jinja_out.display()
                        );
                        std::fs::write(&chat_template_jinja_out, template)
                            .map_err(candle_core::Error::msg)?;
                    } else {
                        info!(
                            "Serializing tokenizer config to `{}`.",
                            tokenizer_cfg_out.display()
                        );
                        std::fs::write(&tokenizer_cfg_out, template)
                            .map_err(candle_core::Error::msg)?;
                    }
                }

                if let Some(generation_config) = generation_config {
                    info!(
                        "Serializing generation config to `{}`.",
                        gen_cfg_out.display()
                    );

                    let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(processor_config) = processor_filename {
                    info!(
                        "Serializing processor config to `{}`.",
                        processor_out.display()
                    );

                    let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(preprocessor_config) = preprocessor_filename {
                    info!(
                        "Serializing preprocessor config to `{}`.",
                        preprocessor_out.display()
                    );

                    let cfg =
                        std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
                }

                // Embedding-module manifest plus per-module config/model files,
                // copied into subdirectories named by each module's path.
                if let Some(modules) = modules {
                    info!(
                        "Serializing modules manifest to `{}`.",
                        modules_out.display()
                    );

                    std::fs::write(&modules_out, modules).map_err(candle_core::Error::msg)?;

                    if let Some(module_paths) = module_paths {
                        for module in module_paths {
                            match module {
                                EmbeddingModulePaths::Transformer { path }
                                | EmbeddingModulePaths::Pooling { path, .. }
                                | EmbeddingModulePaths::Dense { path, .. }
                                | EmbeddingModulePaths::Normalize { path } => {
                                    if path.is_empty() {
                                        continue;
                                    }
                                    let module_dir = parent.join(path.as_str());
                                    std::fs::create_dir_all(&module_dir)
                                        .map_err(candle_core::Error::msg)?;

                                    // Copy module files only when source and
                                    // destination differ (avoids self-copy).
                                    match module {
                                        EmbeddingModulePaths::Pooling { config, .. } => {
                                            let dest = module_dir.join("config.json");
                                            if config != &dest {
                                                std::fs::copy(config, &dest)
                                                    .map_err(candle_core::Error::msg)?;
                                            }
                                        }
                                        EmbeddingModulePaths::Dense { config, model, .. } => {
                                            let dest_cfg = module_dir.join("config.json");
                                            if config != &dest_cfg {
                                                std::fs::copy(config, &dest_cfg)
                                                    .map_err(candle_core::Error::msg)?;
                                            }
                                            let dest_model = module_dir.join("model.safetensors");
                                            if model != &dest_model {
                                                std::fs::copy(model, &dest_model)
                                                    .map_err(candle_core::Error::msg)?;
                                            }
                                        }
                                        EmbeddingModulePaths::Transformer { .. }
                                        | EmbeddingModulePaths::Normalize { .. } => {}
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        Ok(())
    }
866
    /// Load previously serialized UQFF artifacts back into the model's ISQ
    /// layers, deserializing each tensor onto its mapped device.
    ///
    /// Artifact tensor names must parse as `usize` layer indices (the format
    /// written by `quantize`), and the artifact count must equal the number of
    /// ISQ layers.
    fn load_from_artifacts(
        &mut self,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        artifacts: &[PathBuf],
    ) -> candle_core::Result<()> {
        let (tensors, mapper) = self.get_layers();
        let total_tensors = tensors.len();

        // Per-layer (isq, device) overrides from the topology, if provided.
        let layers = topology.map(|x| {
            x.layers
                .iter()
                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                .collect::<Vec<_>>()
        });

        // Resolve target device and communicator per layer: topology entry
        // wins, then the device mapper, then the `device` fallback.
        let mut devices = Vec::new();
        let mut comms = Vec::new();
        for (_, layer_num) in &tensors {
            let device = if let Some(ref layers) = layers {
                if let Some(layer) = layer_num {
                    layers
                        .get(*layer)
                        .as_ref()
                        .map(|x| x.1.clone())
                        .unwrap_or(Some(device.clone()))
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                }
            } else if let Some(layer_num) = layer_num {
                mapper
                    .device_for(*layer_num, false)
                    .cloned()
                    .unwrap_or(device.clone())
            } else {
                device.clone()
            };
            devices.push(device);
            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
        }

        // SAFETY: memory-maps the artifact files; assumes they are not
        // modified while mapped (standard mmap invariant).
        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::multi(artifacts)? };

        // Key each artifact tensor by its numeric name (the layer index).
        let artifact_isqs = artifacts
            .tensors()
            .into_iter()
            .map(|(name, tensor)| {
                (
                    name.parse::<usize>()
                        .expect("Name should be parseable as usize"),
                    tensor,
                )
            })
            .collect::<HashMap<_, _>>();

        if artifact_isqs.len() != total_tensors {
            candle_core::bail!(
                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
                artifact_isqs.len(),
            );
        }

        let bar = ProgressBar::new(total_tensors as u64);
        configure_progress_bar(&bar);
        bar.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                .unwrap()
                .progress_chars("#>-"),
        );

        let t_start = Instant::now();

        let guard = QuantizeOntoGuard::new();

        // The two branches below are identical except that the non-silent one
        // attaches the progress bar. Each artifact is dispatched either to a
        // distributed layer deserializer (column/row-parallel, replicated) or,
        // for non-distributed layers, by the quant-type byte stored at
        // UQFF_QUANT_TYPE_OFFSET in the serialized data.
        if silent {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        } else {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .progress_with(bar)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        }

        let delta = Instant::now().duration_since(t_start).as_secs_f32();
        info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s", );

        Ok(())
    }
1096}
1097
/// Per-model-loader hooks declaring which layer names ISQ may quantize,
/// expressed as regexes matched against tensor names.
pub(crate) trait IsqModelLoader {
    /// Regexes for layers to quantize immediately at load time. Default: none.
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// MoQE (experts-only) variant of `immediate_isq_predicates`; defaults to
    /// the full `isq_layer_regexes` set.
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }

    /// Regexes for all ISQ-eligible layer names. Default: none.
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// MoQE (experts-only) variant of `isq_layer_regexes`; defaults to the
    /// full set.
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
1127}