mistralrs_core/pipeline/
isq.rs

1use std::{
2    borrow::Cow,
3    collections::{HashMap, HashSet},
4    env,
5    fs::File,
6    path::PathBuf,
7    str::FromStr,
8    sync::{atomic::AtomicUsize, Arc},
9    time::Instant,
10};
11
12use anyhow::Result;
13use candle_core::{quantized, Context, Device, Tensor};
14use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
15use itertools::Itertools;
16use mistralrs_quant::{
17    AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
18    HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
19    ReplicatedLayer, RowParallelLayer, UnquantLinear,
20};
21use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
22use regex::Regex;
23use serde::Deserialize;
24use tokenizers::Tokenizer;
25use tracing::{info, warn};
26
27use crate::{device_map::DeviceMapper, topology::LayerTopology, Topology};
28
/// File name, relative to the UQFF output directory, that the residual tensors are written to.
pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
/// Upper bound on the estimated size of one UQFF file: 10 GiB. Larger outputs are split
/// into several shards.
const MAX_UQFF_SIZE_BYTES: usize = 10 * 1024 * 1024 * 1024;
/// Delimiter between file names when several UQFF files are given as a single string.
// NOTE(review): not referenced in this part of the file; confirm the description against its users.
pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";
33
34/// Parse ISQ value.
35///
36/// If the provided value is a valid integer (one of 2,3,4,5,6,8), the best quantization type will be chosen.
37/// Note that the fallback is always a Q/K quantization but on Metal 2,3,4,6,8 uses the fast AFQ.
38///
39/// One of:
40/// - `Q4_0`
41/// - `Q4_1`
42/// - `Q5_0`
43/// - `Q5_1`
44/// - `Q8_0`
45/// - `Q8_1`
46/// - `Q2K`
47/// - `Q3K`
48/// - `Q4K`
49/// - `Q5K`
50/// - `Q6K`
51/// - `Q8K`
52/// - `HQQ1`
53/// - `HQQ2`
54/// - `HQQ3`
55/// - `HQQ4`
56/// - `HQQ8`
57/// - `AFQ2`
58/// - `AFQ3`
59/// - `AFQ4`
60/// - `AFQ6`
61/// - `AFQ8`
62pub fn parse_isq_value(s: &str) -> Result<IsqType, String> {
63    let tp = match s.to_lowercase().as_str() {
64        "2" if cfg!(feature = "metal") => IsqType::AFQ2,
65        "2" if !cfg!(feature = "metal") => IsqType::Q2K,
66        "3" if cfg!(feature = "metal") => IsqType::AFQ3,
67        "3" if !cfg!(feature = "metal") => IsqType::Q3K,
68        "4" if cfg!(feature = "metal") => IsqType::AFQ4,
69        "4" if !cfg!(feature = "metal") => IsqType::Q4K,
70        "5" => IsqType::Q5K,
71        "6" if cfg!(feature = "metal") => IsqType::AFQ6,
72        "6" if !cfg!(feature = "metal") => IsqType::Q6K,
73        "8" if cfg!(feature = "metal") => IsqType::AFQ8,
74        "8" if !cfg!(feature = "metal") => IsqType::Q8_0,
75        "q4_0" => IsqType::Q4_0,
76        "q4_1" => IsqType::Q4_1,
77        "q5_0" => IsqType::Q5_0,
78        "q5_1" => IsqType::Q5_1,
79        "q8_0" => IsqType::Q8_0,
80        "q8_1" => IsqType::Q8_1,
81        "q2k" => IsqType::Q2K,
82        "q3k" => IsqType::Q3K,
83        "q4k" => IsqType::Q4K,
84        "q5k" => IsqType::Q5K,
85        "q6k" => IsqType::Q6K,
86        "q8k" => IsqType::Q8K,
87        "hqq8" => IsqType::HQQ8,
88        "hqq4" => IsqType::HQQ4,
89        "fp8" => IsqType::F8E4M3,
90        "afq8" => IsqType::AFQ8,
91        "afq6" => IsqType::AFQ6,
92        "afq4" => IsqType::AFQ4,
93        "afq3" => IsqType::AFQ3,
94        "afq2" => IsqType::AFQ2,
95        // "hqq3" => IsqType::HQQ3,
96        // "hqq2" => IsqType::HQQ2,
97        // "hqq1" => IsqType::HQQ1,
98        _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
99    };
100    #[cfg(feature = "cuda")]
101    {
102        if !matches!(
103            tp,
104            IsqType::Q4_0
105                | IsqType::Q4_1
106                | IsqType::Q5_0
107                | IsqType::Q5_1
108                | IsqType::Q8_0
109                | IsqType::Q2K
110                | IsqType::Q3K
111                | IsqType::Q4K
112                | IsqType::Q5K
113                | IsqType::Q6K
114                | IsqType::HQQ8
115                | IsqType::HQQ4
116                | IsqType::F8E4M3 // | IsqType::HQQ3
117                                  // | IsqType::HQQ2
118                                  // | IsqType::HQQ1
119        ) {
120            return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
121        }
122    }
123    Ok(tp)
124}
125
/// Selects which set of layers in-situ quantization is applied to.
///
/// The serialized names are `default` and `moqe`.
#[derive(Clone, Debug, Copy, Default, Deserialize)]
pub enum IsqOrganization {
    /// Quantize the layers reported by the model's default layer listing.
    #[default]
    #[serde(rename = "default")]
    Default,
    /// Only quantize MoE experts, if applicable. This enables MoQE.
    /// <https://arxiv.org/abs/2310.02410>
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}
136
137impl FromStr for IsqOrganization {
138    type Err = String;
139    fn from_str(s: &str) -> Result<Self, Self::Err> {
140        match s {
141            "default" => Ok(Self::Default),
142            "moqe" => Ok(Self::MoeExpertsOnly),
143            other => Err(format!(
144                "Expected ISQ organization `default` or `moqe`, got `{other}`"
145            )),
146        }
147    }
148}
149
/// Extra model files that are written next to the UQFF output when quantized artifacts
/// are serialized.
pub struct UqffFullSer<'a> {
    /// Tokenizer, serialized as pretty-printed JSON to `tokenizer.json`.
    pub tokenizer: &'a Tokenizer,
    /// Optional file whose bytes are copied to `tokenizer_config.json`.
    pub template_filename: &'a Option<PathBuf>,
    /// Optional file whose bytes are copied to `generation_config.json`.
    pub generation_config: Option<&'a PathBuf>,
    /// Contents written verbatim to `config.json`.
    pub config: String,
    /// Optional file whose bytes are copied to `processor_config.json`.
    pub processor_filename: &'a Option<PathBuf>,
    /// Optional file whose bytes are copied to `preprocessor_config.json`.
    pub preprocessor_filename: &'a Option<PathBuf>,
}
158
/// Where the importance-matrix (imatrix) data used during quantization comes from.
#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    /// Load the data from a file. A `.cimatrix` extension is read as previously collected
    /// data; any other extension is read as a GGUF-format imatrix file.
    File(&'a PathBuf),
    /// Use the statistics tracked on the model's own layers.
    Collected,
}
164
165pub trait IsqModel {
    /// Corresponds to `IsqOrganization::Default`
    ///
    /// Returns the quantizable layers, each paired with an optional layer index, together
    /// with the model's device mapper. The position of a layer in the returned vector is
    /// the index used for imatrix data and for serialized UQFF tensor names.
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );
174
175    /// This is used for imatrix generation internally. Begin stats tracking.
176    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
177        let layers = self
178            .get_layers()
179            .0
180            .into_iter()
181            .map(|(layer, _)| layer)
182            .collect::<Vec<_>>();
183        for layer in layers {
184            Arc::get_mut(layer).unwrap().begin_track_stats()?;
185        }
186        Ok(())
187    }
188
189    /// End stats tracking and return the imatrix data
190    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
191        let layers = self
192            .get_layers()
193            .0
194            .into_iter()
195            .enumerate()
196            .map(|(i, (layer, _))| (i, layer))
197            .collect::<Vec<_>>();
198        let mut data = HashMap::new();
199        for (i, layer) in layers {
200            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
201        }
202        Ok(CollectedImatrixData(data))
203    }
204
    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
    /// <https://arxiv.org/abs/2310.02410>
    ///
    /// The default implementation returns the same layers as `get_layers`; models
    /// override it to restrict the selection.
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }
216
217    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
218    /// This is used for imatrix generation internally. Begin stats tracking.
219    fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
220        let layers = self
221            .get_layers()
222            .0
223            .into_iter()
224            .map(|(layer, _)| layer)
225            .collect::<Vec<_>>();
226        for layer in layers {
227            Arc::get_mut(layer).unwrap().begin_track_stats()?;
228        }
229        Ok(())
230    }
231
232    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
233    /// End stats tracking and return the imatrix data
234    fn extract_imatrix_data_moe_experts_only(
235        &mut self,
236    ) -> candle_core::Result<CollectedImatrixData> {
237        let layers = self
238            .get_layers()
239            .0
240            .into_iter()
241            .enumerate()
242            .map(|(i, (layer, _))| (i, layer))
243            .collect::<Vec<_>>();
244        let mut data = HashMap::new();
245        for (i, layer) in layers {
246            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
247        }
248        Ok(CollectedImatrixData(data))
249    }
250
    /// Corresponding to the specific order the model produces ISQ layers (None means
    /// do not search for in the imatrix file). This is used to pair ISQ layers with the
    /// corresponding imatrix weights.
    ///
    /// - This is only for loading from a llama.cpp imatrix file.
    /// - Corresponds to `IsqOrganization::Default`
    ///
    /// The default implementation always returns an error; models that support
    /// imatrix quantization override it.
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // TODO: make this required.
        candle_core::bail!("This model does not support quantizing with an imatrix.");
    }
261
    /// Residual tensors for generating a UQFF file. Counterpart to [`Self::get_layers`].
    ///
    /// Returns `(name, tensor)` pairs; `quantize` writes them to the residual safetensors
    /// file next to the UQFF output.
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;
264
    /// Residual tensors for generating a UQFF file. Counterpart to
    /// [`Self::get_layers_moe_experts_only`].
    ///
    /// The default implementation returns `None`, in which case `quantize` falls back to
    /// [`Self::residual_tensors`].
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }
269
270    /// Quantize the model in-situ.
271    ///
272    /// This function will also create a UQFF file, or, if the model supports it (residual tensors are returned),
273    /// a full serialization is created.
274    #[allow(clippy::too_many_arguments)]
275    fn quantize(
276        &mut self,
277        dtype: Option<IsqType>,
278        device: Device,
279        topology: Option<&Topology>,
280        silent: bool,
281        imatrix_source: Option<ImatrixDataSource<'_>>,
282        organization: IsqOrganization,
283        write_artifacts: Option<&PathBuf>,
284        full_ser: UqffFullSer<'_>,
285        multi_progress: Arc<MultiProgress>,
286    ) -> candle_core::Result<()> {
287        {
288            let imatrix_to_weight = match imatrix_source {
289                Some(ImatrixDataSource::File(imatrix)) => {
290                    let ext = imatrix.extension().ok_or(candle_core::Error::msg(
291                        "Expected an extension for the imatrix source file.",
292                    ))?;
293                    if ext == "cimatrix" {
294                        info!(
295                            "Loading collected imatrix source file: `{}`",
296                            imatrix.display()
297                        );
298                        Some(CollectedImatrixData::load_imatrix(imatrix)?.0)
299                    } else if ext == "imatrix" {
300                        info!(
301                            "Loading GGUF-format imatrix source file: `{}`",
302                            imatrix.display()
303                        );
304                        let mut imatrix_data =
305                            quantized::imatrix_file::load_imatrix(imatrix.clone())?;
306                        let imatrix_mapping = self
307                            .imatrix_names()?
308                            .into_iter()
309                            .enumerate()
310                            .collect::<HashMap<_, _>>();
311
312                        let layer_to_weight = imatrix_mapping
313                            .into_iter()
314                            .map(|(i, name)| {
315                                if let Some(name) = name {
316                                    (i, Some(imatrix_data.remove(&name).unwrap()))
317                                } else {
318                                    (i, None)
319                                }
320                            })
321                            .collect::<HashMap<_, _>>();
322                        info!(
323                            "Quantizing with imatrix file `{}`, {} imatrix weights",
324                            imatrix.display(),
325                            layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
326                        );
327                        Some(layer_to_weight)
328                    } else {
329                        warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
330                        info!(
331                            "Loading GGUF-format imatrix source file: `{}`",
332                            imatrix.display()
333                        );
334
335                        let mut imatrix_data =
336                            quantized::imatrix_file::load_imatrix(imatrix.clone())?;
337                        let imatrix_mapping = self
338                            .imatrix_names()?
339                            .into_iter()
340                            .enumerate()
341                            .collect::<HashMap<_, _>>();
342
343                        let layer_to_weight = imatrix_mapping
344                            .into_iter()
345                            .map(|(i, name)| {
346                                if let Some(name) = name {
347                                    (i, Some(imatrix_data.remove(&name).unwrap()))
348                                } else {
349                                    (i, None)
350                                }
351                            })
352                            .collect::<HashMap<_, _>>();
353                        info!(
354                            "Quantizing with imatrix file `{}`, {} imatrix weights",
355                            imatrix.display(),
356                            layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
357                        );
358                        Some(layer_to_weight)
359                    }
360                }
361                Some(ImatrixDataSource::Collected) => {
362                    let data = match organization {
363                        IsqOrganization::Default => self.extract_imatrix_data()?,
364                        IsqOrganization::MoeExpertsOnly => {
365                            self.extract_imatrix_data_moe_experts_only()?
366                        }
367                    };
368                    // Save the collected imatrix data so users can reuse it
369                    let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
370                    let save_path = format!("collected-{count}.cimatrix");
371                    info!("Saving collected imatrix data to `{save_path}`");
372                    data.save_imatrix(save_path)?;
373                    info!("Quantizing with collected imatrix data, {count} imatrix weights");
374                    Some(data.0)
375                }
376                None => {
377                    // Dummy, just for zip
378                    None
379                }
380            };
381
382            let (mut tensors, mapper) = match organization {
383                IsqOrganization::Default => self.get_layers(),
384                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
385            };
386
387            let imatrix_to_weight: Vec<Option<Vec<f32>>> =
388                if let Some(mut imatrix_to_weight) = imatrix_to_weight {
389                    let ordered_keys = imatrix_to_weight
390                        .keys()
391                        .copied()
392                        .sorted()
393                        .collect::<Vec<_>>();
394                    ordered_keys
395                        .into_iter()
396                        .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
397                        .collect()
398                } else {
399                    vec![None; tensors.len()]
400                };
401
402            let total_tensors = tensors.len();
403            let n_quantized = AtomicUsize::new(0);
404            if let Some(topology) = topology {
405                let mut dtypes = HashSet::new();
406                for layer in topology.0.iter().flatten() {
407                    if let LayerTopology {
408                        isq: Some(isq_dtype),
409                        device: _,
410                    } = layer
411                    {
412                        dtypes.insert(isq_dtype);
413                    }
414                }
415                info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
416            } else {
417                info!("Applying in-situ quantization into {dtype:?} to {total_tensors} tensors.");
418            }
419            let bar = ProgressBar::new(total_tensors as u64);
420            bar.set_style(
421                ProgressStyle::default_bar()
422                    .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
423                    .unwrap()
424                    .progress_chars("#>-"),
425            );
426            multi_progress.add(bar.clone());
427
428            let layers = topology.map(|x| {
429                x.0.iter()
430                    .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
431                    .collect::<Vec<_>>()
432            });
433
434            let mut devices_and_dtypes = Vec::new();
435            for (_, layer_num) in &tensors {
436                let device = if let Some(ref layers) = layers {
437                    if let Some(layer) = layer_num {
438                        layers
439                            .get(*layer)
440                            .as_ref()
441                            .map(|x| x.1.clone())
442                            .unwrap_or(Some(device.clone()))
443                            .unwrap_or(device.clone())
444                    } else {
445                        device.clone()
446                    }
447                } else if let Some(layer_num) = layer_num {
448                    mapper
449                        .device_for(*layer_num, false)
450                        .cloned()
451                        .unwrap_or(device.clone())
452                } else {
453                    device.clone()
454                };
455                let dtype = if let Some(ref layers) = layers {
456                    if let Some(layer) = layer_num {
457                        layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
458                    } else {
459                        dtype
460                    }
461                } else {
462                    dtype
463                };
464                devices_and_dtypes.push((device, dtype));
465            }
466
467            let t_start = Instant::now();
468
469            use rayon::iter::IntoParallelRefIterator;
470
471            // Get the MINIMUM of the max isq threads the quant method
472            let mut minimum_max_threads = {
473                let current_rayon_threads = rayon::current_num_threads();
474                if let Some(dtype) = dtype {
475                    dtype
476                        .get_max_isq_cpu_threads()
477                        .map(usize::from)
478                        .unwrap_or(current_rayon_threads)
479                } else {
480                    current_rayon_threads
481                }
482            };
483            if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
484                minimum_max_threads = 1;
485            }
486
487            if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
488                // Collected imatrix means that the model is potentially on the gpu already
489                minimum_max_threads = 1;
490            }
491
492            info!("Applying ISQ on {minimum_max_threads} threads.");
493
494            let pool = rayon::ThreadPoolBuilder::new()
495                .num_threads(minimum_max_threads)
496                .build()
497                .map_err(candle_core::Error::msg)?;
498
499            let guard = QuantizeOntoGuard::new();
500
501            pool.install(|| {
502                use indicatif::ParallelProgressIterator;
503                use rayon::iter::{
504                    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
505                };
506                if silent {
507                    tensors
508                        .par_iter_mut()
509                        .zip(devices_and_dtypes)
510                        .zip(imatrix_to_weight)
511                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
512                            **tensor = tensor
513                                .clone()
514                                .apply_isq(
515                                    dtype,
516                                    device.clone(),
517                                    &n_quantized,
518                                    imatrix_weight,
519                                    guard.clone(),
520                                )
521                                .unwrap();
522                            device.synchronize().unwrap();
523                        });
524                } else {
525                    tensors
526                        .par_iter_mut()
527                        .zip(devices_and_dtypes)
528                        .zip(imatrix_to_weight)
529                        .progress_with(bar)
530                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
531                            **tensor = tensor
532                                .clone()
533                                .apply_isq(
534                                    dtype,
535                                    device.clone(),
536                                    &n_quantized,
537                                    imatrix_weight,
538                                    guard.clone(),
539                                )
540                                .unwrap();
541                            device.synchronize().unwrap();
542                        });
543                }
544            });
545
546            if let Some(serialized) = write_artifacts {
547                info!(
548                    "Serializing {total_tensors} ISQ tensors to `{}`.",
549                    serialized.display()
550                );
551
552                if serialized.extension().is_none_or(|ext| ext != "uqff") {
553                    candle_core::bail!("UQFF output path extension must be `.uqff`",);
554                }
555
556                let bar = ProgressBar::new(total_tensors as u64);
557                bar.set_style(
558                    ProgressStyle::default_bar()
559                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
560                        .unwrap()
561                        .progress_chars("#>-"),
562                );
563
564                #[cfg(not(feature = "metal"))]
565                let n_threads = 2;
566                #[cfg(feature = "metal")]
567                let n_threads = 1;
568
569                let pool = rayon::ThreadPoolBuilder::new()
570                    .num_threads(n_threads)
571                    .build()
572                    .map_err(candle_core::Error::msg)?;
573
574                let quantized_values = pool.install(|| {
575                    if silent {
576                        tensors
577                            .par_iter()
578                            .enumerate()
579                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
580                            .map(|(i, (layer, _))| {
581                                Ok((
582                                    i.to_string(),
583                                    Tensor::new(Cow::into_owned(layer.serialize()?), &Device::Cpu)?,
584                                ))
585                            })
586                            .collect::<candle_core::Result<Vec<_>>>()
587                    } else {
588                        tensors
589                            .par_iter()
590                            .enumerate()
591                            .progress_with(bar)
592                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
593                            .map(|(i, (layer, _))| {
594                                Ok((
595                                    i.to_string(),
596                                    Tensor::new(Cow::into_owned(layer.serialize()?), &Device::Cpu)?,
597                                ))
598                            })
599                            .collect::<candle_core::Result<Vec<_>>>()
600                    }
601                });
602                let quantized_values = quantized_values?;
603
604                let parent = serialized
605                    .parent()
606                    .context("Target UQFF path must have a filename!")?;
607
608                std::fs::create_dir_all(parent)?;
609
610                let file_stem = serialized
611                    .file_stem()
612                    .context("Target UQFF path must have a file stem!")?
613                    .to_string_lossy()
614                    .to_string();
615
616                let size_estimate_bytes = quantized_values
617                    .iter()
618                    .map(|(_, x)| x.elem_count() * x.dtype().size_in_bytes())
619                    .sum::<usize>();
620                let n_files = size_estimate_bytes.div_ceil(MAX_UQFF_SIZE_BYTES);
621
622                if n_files == 1 {
623                    info!("Writing to `{}`", serialized.display());
624                    safetensors::serialize_to_file(quantized_values, &None, serialized)?;
625                } else {
626                    let chunksize = quantized_values.len() / n_files;
627                    let quantized_values_chunks = quantized_values.into_iter().chunks(chunksize);
628                    for (i, chunk) in quantized_values_chunks.into_iter().enumerate() {
629                        let mut name = parent.to_path_buf();
630                        name.push(format!("{file_stem}-{i}.uqff"));
631                        info!("Writing shard {i} to `{}`", name.display());
632                        safetensors::serialize_to_file(chunk, &None, &name)?;
633                    }
634                }
635
636                let residual = match organization {
637                    IsqOrganization::Default => self.residual_tensors(),
638                    IsqOrganization::MoeExpertsOnly => self
639                        .residual_tensors_moe_experts_only()
640                        .unwrap_or(self.residual_tensors()),
641                };
642
643                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
644                let config_out = parent.join("config.json");
645                let tokenizer_out = parent.join("tokenizer.json");
646                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
647                let gen_cfg_out = parent.join("generation_config.json");
648                let processor_out = parent.join("processor_config.json");
649                let preprocessor_out = parent.join("preprocessor_config.json");
650
651                info!(
652                    "Serializing {} residual tensors to `{}`.",
653                    residual.len(),
654                    residual_out.display()
655                );
656
657                safetensors::serialize_to_file(residual, &None, &residual_out)?;
658
659                let UqffFullSer {
660                    tokenizer,
661                    template_filename,
662                    generation_config,
663                    config,
664                    processor_filename,
665                    preprocessor_filename,
666                } = full_ser;
667
668                info!("Serializing configuration to `{}`.", config_out.display());
669
670                std::fs::write(config_out, config)?;
671
672                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());
673
674                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
675                    .map_err(candle_core::Error::msg)?;
676
677                if let Some(template_filename) = template_filename {
678                    info!(
679                        "Serializing tokenizer config to `{}`.",
680                        tokenizer_cfg_out.display()
681                    );
682
683                    let template =
684                        std::fs::read(template_filename).map_err(candle_core::Error::msg)?;
685                    std::fs::write(&tokenizer_cfg_out, template)
686                        .map_err(candle_core::Error::msg)?;
687                }
688
689                if let Some(generation_config) = generation_config {
690                    info!(
691                        "Serializing generation config to `{}`.",
692                        gen_cfg_out.display()
693                    );
694
695                    let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
696                    std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
697                }
698
699                if let Some(processor_config) = processor_filename {
700                    info!(
701                        "Serializing processor config to `{}`.",
702                        processor_out.display()
703                    );
704
705                    let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
706                    std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
707                }
708
709                if let Some(preprocessor_config) = preprocessor_filename {
710                    info!(
711                        "Serializing preprocessor config to `{}`.",
712                        preprocessor_out.display()
713                    );
714
715                    let cfg =
716                        std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
717                    std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
718                }
719            }
720            let delta = Instant::now().duration_since(t_start).as_secs_f32();
721            info!("Applied in-situ quantization into {dtype:?} to {n_quantized:?} tensors out of {total_tensors} total tensors. Took {delta:.2}s", );
722        }
723        Ok(())
724    }
725
726    fn load_from_artifacts(
727        &mut self,
728        device: Device,
729        topology: Option<&Topology>,
730        silent: bool,
731        artifacts: &[PathBuf],
732    ) -> candle_core::Result<()> {
733        let (tensors, mapper) = self.get_layers();
734        let total_tensors = tensors.len();
735
736        let layers = topology.map(|x| {
737            x.0.iter()
738                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
739                .collect::<Vec<_>>()
740        });
741
742        let mut devices = Vec::new();
743        let mut comms = Vec::new();
744        for (_, layer_num) in &tensors {
745            let device = if let Some(ref layers) = layers {
746                if let Some(layer) = layer_num {
747                    layers
748                        .get(*layer)
749                        .as_ref()
750                        .map(|x| x.1.clone())
751                        .unwrap_or(Some(device.clone()))
752                        .unwrap_or(device.clone())
753                } else {
754                    device.clone()
755                }
756            } else if let Some(layer_num) = layer_num {
757                mapper
758                    .device_for(*layer_num, false)
759                    .cloned()
760                    .unwrap_or(device.clone())
761            } else {
762                device.clone()
763            };
764            devices.push(device);
765            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
766        }
767
768        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::multi(artifacts)? };
769
770        let artifact_isqs = artifacts
771            .tensors()
772            .into_iter()
773            .map(|(name, tensor)| {
774                (
775                    name.parse::<usize>()
776                        .expect("Name should be parseable as usize"),
777                    tensor,
778                )
779            })
780            .collect::<HashMap<_, _>>();
781
782        if artifact_isqs.len() != total_tensors {
783            candle_core::bail!(
784                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
785                artifact_isqs.len(),
786            );
787        }
788
789        let bar = ProgressBar::new(total_tensors as u64);
790        bar.set_style(
791            ProgressStyle::default_bar()
792                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
793                .unwrap()
794                .progress_chars("#>-"),
795        );
796
797        let t_start = Instant::now();
798
799        let guard = QuantizeOntoGuard::new();
800
801        if silent {
802            (0..tensors.len())
803                .into_par_iter()
804                .zip(tensors)
805                .map(|(i, (tensor, _))| {
806                    if let Some(artifact) = artifact_isqs.get(&i) {
807                        let artifact = artifact.data();
808
809                        let comm = comms[i].clone();
810                        let deserialized = match tensor.is_distributed() {
811                            Some(DistributedKind::ColumnParallel) => {
812                                ColumnParallelLayer::deserialize(
813                                    Cow::from(artifact),
814                                    &devices[i],
815                                    &comm,
816                                    guard.clone(),
817                                )?
818                            }
819                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
820                                Cow::from(artifact),
821                                &devices[i],
822                                &comm,
823                                guard.clone(),
824                            )?,
825                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
826                                Cow::from(artifact),
827                                &devices[i],
828                                &comm,
829                                guard.clone(),
830                            )?,
831                            None => {
832                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
833                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
834                                match QuantizedSerdeType::try_from(isq_type as usize)? {
835                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
836                                        Cow::from(artifact),
837                                        &devices[i],
838                                        &comm,
839                                        guard.clone(),
840                                    )?,
841                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
842                                        Cow::from(artifact),
843                                        &devices[i],
844                                        &comm,
845                                        guard.clone(),
846                                    )?,
847                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
848                                        Cow::from(artifact),
849                                        &devices[i],
850                                        &comm,
851                                        guard.clone(),
852                                    )?,
853                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
854                                        Cow::from(artifact),
855                                        &devices[i],
856                                        &comm,
857                                        guard.clone(),
858                                    )?,
859                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
860                                        Cow::from(artifact),
861                                        &devices[i],
862                                        &comm,
863                                        guard.clone(),
864                                    )?,
865                                }
866                            }
867                        };
868                        *tensor = deserialized;
869                    }
870                    Ok(())
871                })
872                .collect::<candle_core::Result<Vec<_>>>()?;
873        } else {
874            (0..tensors.len())
875                .into_par_iter()
876                .zip(tensors)
877                .progress_with(bar)
878                .map(|(i, (tensor, _))| {
879                    if let Some(artifact) = artifact_isqs.get(&i) {
880                        let artifact = artifact.data();
881
882                        let comm = comms[i].clone();
883                        let deserialized = match tensor.is_distributed() {
884                            Some(DistributedKind::ColumnParallel) => {
885                                ColumnParallelLayer::deserialize(
886                                    Cow::from(artifact),
887                                    &devices[i],
888                                    &comm,
889                                    guard.clone(),
890                                )?
891                            }
892                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
893                                Cow::from(artifact),
894                                &devices[i],
895                                &comm,
896                                guard.clone(),
897                            )?,
898                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
899                                Cow::from(artifact),
900                                &devices[i],
901                                &comm,
902                                guard.clone(),
903                            )?,
904                            None => {
905                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
906                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
907                                match QuantizedSerdeType::try_from(isq_type as usize)? {
908                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
909                                        Cow::from(artifact),
910                                        &devices[i],
911                                        &comm,
912                                        guard.clone(),
913                                    )?,
914                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
915                                        Cow::from(artifact),
916                                        &devices[i],
917                                        &comm,
918                                        guard.clone(),
919                                    )?,
920                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
921                                        Cow::from(artifact),
922                                        &devices[i],
923                                        &comm,
924                                        guard.clone(),
925                                    )?,
926                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
927                                        Cow::from(artifact),
928                                        &devices[i],
929                                        &comm,
930                                        guard.clone(),
931                                    )?,
932                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
933                                        Cow::from(artifact),
934                                        &devices[i],
935                                        &comm,
936                                        guard.clone(),
937                                    )?,
938                                }
939                            }
940                        };
941                        *tensor = deserialized;
942                    }
943                    Ok(())
944                })
945                .collect::<candle_core::Result<Vec<_>>>()?;
946        }
947
948        let delta = Instant::now().duration_since(t_start).as_secs_f32();
949        info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s", );
950
951        Ok(())
952    }
953}
954
955/// Trait for loading models with ISQ.
956pub(crate) trait IsqModelLoader {
957    /// Regex to match layers which will have standard ISQ applied.
958    ///
959    /// Only called on non-adapter models!
960    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
961        Ok(Vec::new())
962    }
963
964    /// Regex to match layers which will have standard MoQE ISQ applied.
965    ///
966    /// Only called on non-adapter models!
967    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
968        self.isq_layer_regexes(config)
969    }
970}