mistralrs_core/pipeline/isq.rs

use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    env,
    fs::File,
    path::PathBuf,
    str::FromStr,
    sync::{atomic::AtomicUsize, Arc},
    time::Instant,
};

use anyhow::Result;
use candle_core::{quantized, Context, Device, Tensor};
use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
use itertools::Itertools;
use mistralrs_quant::{
    AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
    HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    ReplicatedLayer, RowParallelLayer, UnquantLinear,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use regex::Regex;
use serde::Deserialize;
use tokenizers::Tokenizer;
use tracing::{info, warn};

use crate::{device_map::DeviceMapper, topology::LayerTopology, Topology};

pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
/// Parse an ISQ value.
///
/// If the provided value is one of the integers 2, 3, 4, 5, 6, or 8, the best
/// quantization type for the platform is chosen automatically: on Metal, 2, 3,
/// 4, 6, and 8 map to the fast AFQ types, while the fallback elsewhere is
/// always a Q/K quantization.
///
/// One of:
/// - `Q4_0`
/// - `Q4_1`
/// - `Q5_0`
/// - `Q5_1`
/// - `Q8_0`
/// - `Q8_1`
/// - `Q2K`
/// - `Q3K`
/// - `Q4K`
/// - `Q5K`
/// - `Q6K`
/// - `Q8K`
/// - `HQQ4`
/// - `HQQ8`
/// - `FP8`
/// - `AFQ2`
/// - `AFQ3`
/// - `AFQ4`
/// - `AFQ6`
/// - `AFQ8`
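///
/// # Example
///
/// A minimal sketch; the integer shorthands resolve differently depending on
/// whether the `metal` feature is enabled:
///
/// ```ignore
/// let explicit = parse_isq_value("q4k")?; // IsqType::Q4K on every platform
/// let auto = parse_isq_value("4")?;       // AFQ4 on Metal, Q4K elsewhere
/// ```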
pub fn parse_isq_value(s: &str) -> Result<IsqType, String> {
    let tp = match s.to_lowercase().as_str() {
        "2" if cfg!(feature = "metal") => IsqType::AFQ2,
        "2" if !cfg!(feature = "metal") => IsqType::Q2K,
        "3" if cfg!(feature = "metal") => IsqType::AFQ3,
        "3" if !cfg!(feature = "metal") => IsqType::Q3K,
        "4" if cfg!(feature = "metal") => IsqType::AFQ4,
        "4" if !cfg!(feature = "metal") => IsqType::Q4K,
        "5" => IsqType::Q5K,
        "6" if cfg!(feature = "metal") => IsqType::AFQ6,
        "6" if !cfg!(feature = "metal") => IsqType::Q6K,
        "8" if cfg!(feature = "metal") => IsqType::AFQ8,
        "8" if !cfg!(feature = "metal") => IsqType::Q8_0,
        "q4_0" => IsqType::Q4_0,
        "q4_1" => IsqType::Q4_1,
        "q5_0" => IsqType::Q5_0,
        "q5_1" => IsqType::Q5_1,
        "q8_0" => IsqType::Q8_0,
        "q8_1" => IsqType::Q8_1,
        "q2k" => IsqType::Q2K,
        "q3k" => IsqType::Q3K,
        "q4k" => IsqType::Q4K,
        "q5k" => IsqType::Q5K,
        "q6k" => IsqType::Q6K,
        "q8k" => IsqType::Q8K,
        "hqq8" => IsqType::HQQ8,
        "hqq4" => IsqType::HQQ4,
        "fp8" => IsqType::F8E4M3,
        "afq8" => IsqType::AFQ8,
        "afq6" => IsqType::AFQ6,
        "afq4" => IsqType::AFQ4,
        "afq3" => IsqType::AFQ3,
        "afq2" => IsqType::AFQ2,
        // "hqq3" => IsqType::HQQ3,
        // "hqq2" => IsqType::HQQ2,
        // "hqq1" => IsqType::HQQ1,
        _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `5`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
    };
    #[cfg(feature = "cuda")]
    {
        if !matches!(
            tp,
            IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q8_0
                | IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q5K
                | IsqType::Q6K
                | IsqType::HQQ8
                | IsqType::HQQ4
                | IsqType::F8E4M3 // | IsqType::HQQ3
                                  // | IsqType::HQQ2
                                  // | IsqType::HQQ1
        ) {
            return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
        }
    }
    Ok(tp)
}

#[derive(Clone, Debug, Copy, Default, Deserialize)]
pub enum IsqOrganization {
    #[default]
    #[serde(rename = "default")]
    Default,
    /// Only quantize MoE experts, if applicable. This enables MoQE.
    /// <https://arxiv.org/abs/2310.02410>
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}

impl FromStr for IsqOrganization {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "default" => Ok(Self::Default),
            "moqe" => Ok(Self::MoeExpertsOnly),
            other => Err(format!(
                "Expected ISQ organization `default` or `moqe`, got `{other}`"
            )),
        }
    }
}

pub struct UqffFullSer<'a> {
    pub tokenizer: &'a Tokenizer,
    pub template_filename: &'a Option<PathBuf>,
    pub generation_config: Option<&'a PathBuf>,
    pub config: String,
    pub processor_filename: &'a Option<PathBuf>,
    pub preprocessor_filename: &'a Option<PathBuf>,
}

#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    File(&'a PathBuf),
    Collected,
}

pub trait IsqModel {
    /// Corresponds to `IsqOrganization::Default`
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );

    /// This is used for imatrix generation internally. Begin stats tracking.
    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

    /// End stats tracking and return the imatrix data.
    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

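    // Sketch of the collected-imatrix flow built from the two methods above;
    // the calibration forward passes themselves are driven by the calling
    // pipeline, not by this trait:
    //
    //     model.begin_track_stats()?;               // start recording activation stats
    //     /* ...run calibration forward passes... */
    //     let data = model.extract_imatrix_data()?; // per-layer f32 statistics
    //     data.save_imatrix("collected.cimatrix")?; // reusable via ImatrixDataSource::File
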
    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
    /// <https://arxiv.org/abs/2310.02410>
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }

    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
    /// This is used for imatrix generation internally. Begin stats tracking.
    fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
    /// End stats tracking and return the imatrix data.
    fn extract_imatrix_data_moe_experts_only(
        &mut self,
    ) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

    /// The names of the model's ISQ layers, in the exact order the model
    /// produces them (`None` means do not search for that layer in the
    /// imatrix file). This is used to pair ISQ layers with the corresponding
    /// imatrix weights.
    ///
    /// - This is only for loading from a llama.cpp imatrix file.
    /// - Corresponds to `IsqOrganization::Default`
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // TODO: make this required.
        candle_core::bail!("This model does not support quantizing with an imatrix.");
    }
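
    // Illustrative shape of the value `imatrix_names` should return; the
    // names here are hypothetical and must match the tensor names recorded in
    // the llama.cpp imatrix file:
    //
    //     Ok(vec![
    //         Some("blk.0.attn_q.weight".to_string()), // paired with imatrix data
    //         None,                                    // skipped during lookup
    //     ])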

    /// Residual tensors for generating a UQFF file. Counterpart to [`get_layers`].
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;

    /// Residual tensors for generating a UQFF file. Counterpart to [`get_layers_moe_experts_only`].
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }

    /// Quantize the model in-situ.
    ///
    /// If `write_artifacts` is provided, this function also creates a UQFF
    /// file and, where the model supports it (residual tensors are returned),
    /// a full serialization alongside it.
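    ///
    /// A call sketch; the surrounding pipeline supplies the concrete model,
    /// paths, and `UqffFullSer` value, and everything shown here is
    /// illustrative:
    ///
    /// ```ignore
    /// model.quantize(
    ///     Some(IsqType::Q4K),
    ///     Device::Cpu,
    ///     None,  // topology
    ///     false, // silent
    ///     None,  // imatrix source
    ///     IsqOrganization::Default,
    ///     None,  // do not write a UQFF artifact
    ///     full_ser,
    ///     Arc::new(MultiProgress::new()),
    /// )?;
    /// ```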
    #[allow(clippy::too_many_arguments)]
    fn quantize(
        &mut self,
        dtype: Option<IsqType>,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        imatrix_source: Option<ImatrixDataSource<'_>>,
        organization: IsqOrganization,
        write_artifacts: Option<&PathBuf>,
        full_ser: UqffFullSer<'_>,
        multi_progress: Arc<MultiProgress>,
    ) -> candle_core::Result<()> {
        {
            let imatrix_to_weight = match imatrix_source {
                Some(ImatrixDataSource::File(imatrix)) => {
                    let ext = imatrix.extension().ok_or(candle_core::Error::msg(
                        "Expected an extension for the imatrix source file.",
                    ))?;
                    if ext == "cimatrix" {
                        info!(
                            "Loading collected imatrix source file: `{}`",
                            imatrix.display()
                        );
                        Some(CollectedImatrixData::load_imatrix(imatrix)?.0)
                    } else if ext == "imatrix" {
                        info!(
                            "Loading GGUF-format imatrix source file: `{}`",
                            imatrix.display()
                        );
                        let mut imatrix_data =
                            quantized::imatrix_file::load_imatrix(imatrix.clone())?;
                        let imatrix_mapping = self
                            .imatrix_names()?
                            .into_iter()
                            .enumerate()
                            .collect::<HashMap<_, _>>();

                        let layer_to_weight = imatrix_mapping
                            .into_iter()
                            .map(|(i, name)| {
                                if let Some(name) = name {
                                    (i, Some(imatrix_data.remove(&name).unwrap()))
                                } else {
                                    (i, None)
                                }
                            })
                            .collect::<HashMap<_, _>>();
                        info!(
                            "Quantizing with imatrix file `{}`, {} imatrix weights",
                            imatrix.display(),
                            layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
                        );
                        Some(layer_to_weight)
                    } else {
                        warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
                        info!(
                            "Loading GGUF-format imatrix source file: `{}`",
                            imatrix.display()
                        );

                        let mut imatrix_data =
                            quantized::imatrix_file::load_imatrix(imatrix.clone())?;
                        let imatrix_mapping = self
                            .imatrix_names()?
                            .into_iter()
                            .enumerate()
                            .collect::<HashMap<_, _>>();

                        let layer_to_weight = imatrix_mapping
                            .into_iter()
                            .map(|(i, name)| {
                                if let Some(name) = name {
                                    (i, Some(imatrix_data.remove(&name).unwrap()))
                                } else {
                                    (i, None)
                                }
                            })
                            .collect::<HashMap<_, _>>();
                        info!(
                            "Quantizing with imatrix file `{}`, {} imatrix weights",
                            imatrix.display(),
                            layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
                        );
                        Some(layer_to_weight)
                    }
                }
                Some(ImatrixDataSource::Collected) => {
                    let data = match organization {
                        IsqOrganization::Default => self.extract_imatrix_data()?,
                        IsqOrganization::MoeExpertsOnly => {
                            self.extract_imatrix_data_moe_experts_only()?
                        }
                    };
                    // Save the collected imatrix data so users can reuse it
                    let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
                    let save_path = format!("collected-{count}.cimatrix");
                    info!("Saving collected imatrix data to `{save_path}`");
                    data.save_imatrix(save_path)?;
                    info!("Quantizing with collected imatrix data, {count} imatrix weights");
                    Some(data.0)
                }
                None => {
                    // Dummy, just for zip
                    None
                }
            };

            let (mut tensors, mapper) = match organization {
                IsqOrganization::Default => self.get_layers(),
                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
            };

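            // Flatten the per-layer imatrix map into a Vec aligned with
            // `tensors`: keys are layer indices, so sorting them reproduces
            // the order in which the model yields its ISQ layers.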
            let imatrix_to_weight: Vec<Option<Vec<f32>>> =
                if let Some(mut imatrix_to_weight) = imatrix_to_weight {
                    let ordered_keys = imatrix_to_weight
                        .keys()
                        .copied()
                        .sorted()
                        .collect::<Vec<_>>();
                    ordered_keys
                        .into_iter()
                        .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
                        .collect()
                } else {
                    vec![None; tensors.len()]
                };

            let total_tensors = tensors.len();
            let n_quantized = AtomicUsize::new(0);
            if let Some(topology) = topology {
                let mut dtypes = HashSet::new();
                for layer in topology.0.iter().flatten() {
                    if let LayerTopology {
                        isq: Some(isq_dtype),
                        device: _,
                    } = layer
                    {
                        dtypes.insert(isq_dtype);
                    }
                }
                info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
            } else {
                info!("Applying in-situ quantization into {dtype:?} to {total_tensors} tensors.");
            }
            let bar = ProgressBar::new(total_tensors as u64);
            bar.set_style(
                ProgressStyle::default_bar()
                    .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
                    .unwrap()
                    .progress_chars("#>-"),
            );
            multi_progress.add(bar.clone());

            let layers = topology.map(|x| {
                x.0.iter()
                    .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                    .collect::<Vec<_>>()
            });

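            // Resolve the target device and ISQ dtype for every tensor: a
            // topology entry for the tensor's layer takes precedence, then the
            // device mapper's placement, then the base `device`/`dtype`.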
            let mut devices_and_dtypes = Vec::new();
            for (_, layer_num) in &tensors {
                let device = if let Some(ref layers) = layers {
                    if let Some(layer) = layer_num {
                        layers
                            .get(*layer)
                            .as_ref()
                            .map(|x| x.1.clone())
                            .unwrap_or(Some(device.clone()))
                            .unwrap_or(device.clone())
                    } else {
                        device.clone()
                    }
                } else if let Some(layer_num) = layer_num {
                    mapper
                        .device_for(*layer_num, false)
                        .cloned()
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                };
                let dtype = if let Some(ref layers) = layers {
                    if let Some(layer) = layer_num {
                        layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
                    } else {
                        dtype
                    }
                } else {
                    dtype
                };
                devices_and_dtypes.push((device, dtype));
            }

            let t_start = Instant::now();

            use rayon::iter::IntoParallelRefIterator;

            // Take the minimum of the quant method's maximum ISQ CPU thread count and the current rayon thread count.
            let mut minimum_max_threads = {
                let current_rayon_threads = rayon::current_num_threads();
                if let Some(dtype) = dtype {
                    dtype
                        .get_max_isq_cpu_threads()
                        .map(usize::from)
                        .unwrap_or(current_rayon_threads)
                } else {
                    current_rayon_threads
                }
            };
            if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
                minimum_max_threads = 1;
            }

            if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
                // A collected imatrix means the model may already be on the GPU, so quantize on a single thread.
                minimum_max_threads = 1;
            }

            info!("Applying ISQ on {minimum_max_threads} threads.");

            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(minimum_max_threads)
                .build()
                .map_err(candle_core::Error::msg)?;

            let guard = QuantizeOntoGuard::new();

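            // The guard is cloned into every `apply_isq` call below so that
            // `mistralrs_quant` can coordinate concurrent quantize-onto-device
            // steps across the worker threads.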
            pool.install(|| {
                use indicatif::ParallelProgressIterator;
                use rayon::iter::{
                    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
                };
                if silent {
                    tensors
                        .par_iter_mut()
                        .zip(devices_and_dtypes)
                        .zip(imatrix_to_weight)
                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                            **tensor = tensor
                                .clone()
                                .apply_isq(
                                    dtype,
                                    device.clone(),
                                    &n_quantized,
                                    imatrix_weight,
                                    guard.clone(),
                                )
                                .unwrap();
                            device.synchronize().unwrap();
                        });
                } else {
                    tensors
                        .par_iter_mut()
                        .zip(devices_and_dtypes)
                        .zip(imatrix_to_weight)
                        .progress_with(bar)
                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                            **tensor = tensor
                                .clone()
                                .apply_isq(
                                    dtype,
                                    device.clone(),
                                    &n_quantized,
                                    imatrix_weight,
                                    guard.clone(),
                                )
                                .unwrap();
                            device.synchronize().unwrap();
                        });
                }
            });

            if let Some(serialized) = write_artifacts {
                info!(
                    "Serializing {total_tensors} ISQ tensors to `{}`.",
                    serialized.display()
                );

                if serialized.extension().is_none_or(|ext| ext != "uqff") {
                    candle_core::bail!("UQFF output path extension must be `.uqff`");
                }

                let bar = ProgressBar::new(total_tensors as u64);
                bar.set_style(
                    ProgressStyle::default_bar()
                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                        .unwrap()
                        .progress_chars("#>-"),
                );

                #[cfg(not(feature = "metal"))]
                let n_threads = 2;
                #[cfg(feature = "metal")]
                let n_threads = 1;

                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(n_threads)
                    .build()
                    .map_err(candle_core::Error::msg)?;

                let quantized_values = pool.install(|| {
                    if silent {
                        tensors
                            .par_iter()
                            .enumerate()
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    Tensor::new(Cow::into_owned(layer.serialize()?), &Device::Cpu)?,
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    } else {
                        tensors
                            .par_iter()
                            .enumerate()
                            .progress_with(bar)
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    Tensor::new(Cow::into_owned(layer.serialize()?), &Device::Cpu)?,
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    }
                });

                let parent = serialized
                    .parent()
                    .context("Target UQFF path must have a filename!")?;

                std::fs::create_dir_all(parent)?;

                safetensors::serialize_to_file(quantized_values?, &None, serialized)?;

                let residual = match organization {
                    IsqOrganization::Default => self.residual_tensors(),
                    IsqOrganization::MoeExpertsOnly => self
                        .residual_tensors_moe_experts_only()
                        .unwrap_or(self.residual_tensors()),
                };

                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
                let config_out = parent.join("config.json");
                let tokenizer_out = parent.join("tokenizer.json");
                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
                let gen_cfg_out = parent.join("generation_config.json");
                let processor_out = parent.join("processor_config.json");
                let preprocessor_out = parent.join("preprocessor_config.json");

                info!(
                    "Serializing {} residual tensors to `{}`.",
                    residual.len(),
                    residual_out.display()
                );

                safetensors::serialize_to_file(residual, &None, &residual_out)?;

                let UqffFullSer {
                    tokenizer,
                    template_filename,
                    generation_config,
                    config,
                    processor_filename,
                    preprocessor_filename,
                } = full_ser;

                info!("Serializing configuration to `{}`.", config_out.display());

                std::fs::write(config_out, config)?;

                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());

                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
                    .map_err(candle_core::Error::msg)?;

                if let Some(template_filename) = template_filename {
                    info!(
                        "Serializing tokenizer config to `{}`.",
                        tokenizer_cfg_out.display()
                    );

                    let template =
                        std::fs::read(template_filename).map_err(candle_core::Error::msg)?;
                    std::fs::write(&tokenizer_cfg_out, template)
                        .map_err(candle_core::Error::msg)?;
                }

                if let Some(generation_config) = generation_config {
                    info!(
                        "Serializing generation config to `{}`.",
                        gen_cfg_out.display()
                    );

                    let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(processor_config) = processor_filename {
                    info!(
                        "Serializing processor config to `{}`.",
                        processor_out.display()
                    );

                    let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(preprocessor_config) = preprocessor_filename {
                    info!(
                        "Serializing preprocessor config to `{}`.",
                        preprocessor_out.display()
                    );

                    let cfg =
                        std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
                }
            }
            let delta = Instant::now().duration_since(t_start).as_secs_f32();
            info!("Applied in-situ quantization into {dtype:?} to {n_quantized:?} tensors out of {total_tensors} total tensors. Took {delta:.2}s");
        }
        Ok(())
    }

    fn load_from_artifacts(
        &mut self,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        artifacts: &PathBuf,
    ) -> candle_core::Result<()> {
        let (tensors, mapper) = self.get_layers();
        let total_tensors = tensors.len();

        let layers = topology.map(|x| {
            x.0.iter()
                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                .collect::<Vec<_>>()
        });

        let mut devices = Vec::new();
        let mut comms = Vec::new();
        for (_, layer_num) in &tensors {
            let device = if let Some(ref layers) = layers {
                if let Some(layer) = layer_num {
                    layers
                        .get(*layer)
                        .as_ref()
                        .map(|x| x.1.clone())
                        .unwrap_or(Some(device.clone()))
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                }
            } else if let Some(layer_num) = layer_num {
                mapper
                    .device_for(*layer_num, false)
                    .cloned()
                    .unwrap_or(device.clone())
            } else {
                device.clone()
            };
            devices.push(device);
            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
        }

        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::new(artifacts)? };

        let artifact_isqs = artifacts
            .tensors()
            .into_iter()
            .map(|(name, tensor)| {
                (
                    name.parse::<usize>()
                        .expect("Name should be parseable as usize"),
                    tensor,
                )
            })
            .collect::<HashMap<_, _>>();

        if artifact_isqs.len() != total_tensors {
            candle_core::bail!(
                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
                artifact_isqs.len(),
            );
        }

        let bar = ProgressBar::new(total_tensors as u64);
        bar.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                .unwrap()
                .progress_chars("#>-"),
        );

        let t_start = Instant::now();

        let guard = QuantizeOntoGuard::new();

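        // Deserialize each artifact onto its resolved device. Distributed
        // layers dispatch on their parallelism kind; for everything else the
        // quantization scheme is read back from the serialized header byte.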
        if silent {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        } else {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .progress_with(bar)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        }

        let delta = Instant::now().duration_since(t_start).as_secs_f32();
        info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s");

        Ok(())
    }
}

/// Trait for loading models with ISQ.
pub(crate) trait IsqModelLoader {
    /// Regex to match layers which will have standard ISQ applied.
    ///
    /// Only called on non-adapter models!
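    ///
    /// For example, an implementation might match the MLP projections of
    /// every decoder layer; this pattern is purely illustrative:
    ///
    /// ```ignore
    /// Ok(vec![Regex::new(
    ///     r"layers\.(\d+)\.mlp\.(up|gate|down)_proj\.(weight|bias)$",
    /// )?])
    /// ```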
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// Regex to match layers which will have standard MoQE ISQ applied.
    ///
    /// Only called on non-adapter models!
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
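
// A minimal sanity-check sketch for the string parsers in this module. These
// assertions hold on any feature set: explicit names always map to the same
// `IsqType`, and `5` has no AFQ counterpart, so it is `Q5K` everywhere.
#[cfg(test)]
mod isq_parse_tests {
    use super::*;

    #[test]
    fn parses_explicit_and_shorthand_isq_types() {
        assert!(matches!(parse_isq_value("q4k"), Ok(IsqType::Q4K)));
        assert!(matches!(parse_isq_value("5"), Ok(IsqType::Q5K)));
        assert!(parse_isq_value("q9k").is_err());
    }

    #[test]
    fn parses_isq_organization() {
        assert!(matches!(
            "default".parse::<IsqOrganization>(),
            Ok(IsqOrganization::Default)
        ));
        assert!(matches!(
            "moqe".parse::<IsqOrganization>(),
            Ok(IsqOrganization::MoeExpertsOnly)
        ));
        assert!("experts".parse::<IsqOrganization>().is_err());
    }
}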