mistralrs_core/pipeline/
isq.rs

1use std::{
2    borrow::Cow,
3    collections::{HashMap, HashSet},
4    env,
5    fs::File,
6    path::PathBuf,
7    str::FromStr,
8    sync::{atomic::AtomicUsize, Arc},
9    time::Instant,
10};
/// Wrapper around a `Cow<'a, [u8]>` buffer that implements
/// `safetensors::tensor::View`.
///
/// *Purpose*: lets us pass raw byte buffers to
/// `safetensors::serialize_to_file` without cloning them into a `Vec<u8>` or
/// converting to a higher-level tensor type.
/// We expose the buffer as a 1-D `u8` tensor of shape `[len]`.
#[derive(Clone)]
pub struct CowBytesView<'a> {
    // The raw bytes handed to the serializer.
    data: Cow<'a, [u8]>,
    // Single-dimension shape; `new` sets it to `[data.len()]`.
    shape: [usize; 1],
}
23
24impl<'a> CowBytesView<'a> {
25    /// Convenience constructor.
26    pub fn new(data: Cow<'a, [u8]>) -> Self {
27        let len = data.len();
28        Self { data, shape: [len] }
29    }
30}
31
32impl<'a> safetensors::tensor::View for CowBytesView<'a> {
33    fn dtype(&self) -> safetensors::tensor::Dtype {
34        // Serialize as raw bytes
35        safetensors::tensor::Dtype::U8
36    }
37
38    fn shape(&self) -> &[usize] {
39        &self.shape
40    }
41
42    fn data(&self) -> Cow<[u8]> {
43        assert!(matches!(self.data, Cow::Borrowed(_)));
44        // Cloning a `Cow` is cheap (only clones the enum, not the data).
45        self.data.clone()
46    }
47
48    fn data_len(&self) -> usize {
49        self.data.len()
50    }
51}
52
53use anyhow::Result;
54use candle_core::{quantized, Context, Device, Tensor};
55use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
56use itertools::Itertools;
57use mistralrs_quant::{
58    AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
59    HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
60    ReplicatedLayer, RowParallelLayer, UnquantLinear,
61};
62use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
63use regex::Regex;
64use serde::Deserialize;
65use tokenizers::Tokenizer;
66use tracing::{info, warn};
67
68use crate::{device_map::DeviceMapper, topology::LayerTopology, Topology};
69
/// File name, joined onto the UQFF output directory, that the residual tensors are written to.
pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
/// Upper bound on the cumulative serialized tensor bytes per UQFF shard (10 GiB).
const MAX_UQFF_SIZE_BYTES: usize = 10 * 1024 * 1024 * 1024;
/// NOTE(review): presumably separates several UQFF file paths packed into one string;
/// its use is not visible here, so confirm against the callers.
pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";
74
75/// Parse ISQ value.
76///
77/// If the provided value is a valid integer (one of 2,3,4,5,6,8), the best quantization type will be chosen.
78/// Note that the fallback is always a Q/K quantization but on Metal 2,3,4,6,8 uses the fast AFQ.
79///
80/// One of:
81/// - `Q4_0`
82/// - `Q4_1`
83/// - `Q5_0`
84/// - `Q5_1`
85/// - `Q8_0`
86/// - `Q8_1`
87/// - `Q2K`
88/// - `Q3K`
89/// - `Q4K`
90/// - `Q5K`
91/// - `Q6K`
92/// - `Q8K`
93/// - `HQQ1`
94/// - `HQQ2`
95/// - `HQQ3`
96/// - `HQQ4`
97/// - `HQQ8`
98/// - `AFQ2`
99/// - `AFQ3`
100/// - `AFQ4`
101/// - `AFQ6`
102/// - `AFQ8`
103pub fn parse_isq_value(s: &str, device: Option<&Device>) -> Result<IsqType, String> {
104    let is_metal = device.map(|device| device.is_metal()).unwrap_or(false);
105    let tp = match s.to_lowercase().as_str() {
106        "2" if is_metal => IsqType::AFQ2,
107        "2" if !is_metal => IsqType::Q2K,
108        "3" if is_metal => IsqType::AFQ3,
109        "3" if !is_metal => IsqType::Q3K,
110        "4" if is_metal => IsqType::AFQ4,
111        "4" if !is_metal => IsqType::Q4K,
112        "5" => IsqType::Q5K,
113        "6" if is_metal => IsqType::AFQ6,
114        "6" if !is_metal => IsqType::Q6K,
115        "8" if is_metal => IsqType::AFQ8,
116        "8" if !is_metal => IsqType::Q8_0,
117        "q4_0" => IsqType::Q4_0,
118        "q4_1" => IsqType::Q4_1,
119        "q5_0" => IsqType::Q5_0,
120        "q5_1" => IsqType::Q5_1,
121        "q8_0" => IsqType::Q8_0,
122        "q8_1" => IsqType::Q8_1,
123        "q2k" => IsqType::Q2K,
124        "q3k" => IsqType::Q3K,
125        "q4k" => IsqType::Q4K,
126        "q5k" => IsqType::Q5K,
127        "q6k" => IsqType::Q6K,
128        "q8k" => IsqType::Q8K,
129        "hqq8" => IsqType::HQQ8,
130        "hqq4" => IsqType::HQQ4,
131        "fp8" => IsqType::F8E4M3,
132        "afq8" => IsqType::AFQ8,
133        "afq6" => IsqType::AFQ6,
134        "afq4" => IsqType::AFQ4,
135        "afq3" => IsqType::AFQ3,
136        "afq2" => IsqType::AFQ2,
137        // "hqq3" => IsqType::HQQ3,
138        // "hqq2" => IsqType::HQQ2,
139        // "hqq1" => IsqType::HQQ1,
140        _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
141    };
142    #[cfg(feature = "cuda")]
143    {
144        if !matches!(
145            tp,
146            IsqType::Q4_0
147                | IsqType::Q4_1
148                | IsqType::Q5_0
149                | IsqType::Q5_1
150                | IsqType::Q8_0
151                | IsqType::Q2K
152                | IsqType::Q3K
153                | IsqType::Q4K
154                | IsqType::Q5K
155                | IsqType::Q6K
156                | IsqType::HQQ8
157                | IsqType::HQQ4
158                | IsqType::F8E4M3 // | IsqType::HQQ3
159                                  // | IsqType::HQQ2
160                                  // | IsqType::HQQ1
161        ) {
162            return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
163        }
164    }
165    Ok(tp)
166}
167
/// Selects which set of layers ISQ is applied to.
#[derive(Clone, Debug, Copy, Default, Deserialize)]
pub enum IsqOrganization {
    /// Quantize the layers reported by `IsqModel::get_layers`.
    #[default]
    #[serde(rename = "default")]
    Default,
    /// Only quantize MoE experts, if applicable. This enables MoQE.
    /// <https://arxiv.org/abs/2310.02410>
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}
178
179impl FromStr for IsqOrganization {
180    type Err = String;
181    fn from_str(s: &str) -> Result<Self, Self::Err> {
182        match s {
183            "default" => Ok(Self::Default),
184            "moqe" => Ok(Self::MoeExpertsOnly),
185            other => Err(format!(
186                "Expected ISQ organization `default` or `moqe`, got `{other}`"
187            )),
188        }
189    }
190}
191
/// Non-tensor artifacts that `IsqModel::quantize` writes next to the UQFF shards
/// when an output path is supplied.
pub struct UqffFullSer<'a> {
    /// Tokenizer, serialized as pretty-printed JSON to `tokenizer.json`.
    pub tokenizer: &'a Tokenizer,
    /// Optional file whose bytes are copied to `tokenizer_config.json`.
    pub template_filename: &'a Option<PathBuf>,
    /// Optional file whose bytes are copied to `generation_config.json`.
    pub generation_config: Option<&'a PathBuf>,
    /// Contents written verbatim to `config.json`.
    pub config: String,
    /// Optional file whose bytes are copied to `processor_config.json`.
    pub processor_filename: &'a Option<PathBuf>,
    /// Optional file whose bytes are copied to `preprocessor_config.json`.
    pub preprocessor_filename: &'a Option<PathBuf>,
}
200
/// Where `IsqModel::quantize` obtains its imatrix weights from.
#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    /// Load from a file: a `.cimatrix` extension is read as collected imatrix data,
    /// any other extension is read as a GGUF-format imatrix file.
    File(&'a PathBuf),
    /// Use the statistics collected on the model itself (see
    /// `IsqModel::extract_imatrix_data`).
    Collected,
}
206
207pub trait IsqModel {
    /// Returns the layers subject to ISQ together with the model's device mapper.
    ///
    /// Each entry pairs a mutable handle to a layer's quantization method with an
    /// optional layer number; `quantize` uses that number to look up the per-layer
    /// device and ISQ type (entries with `None` use the defaults passed to `quantize`).
    ///
    /// Corresponds to `IsqOrganization::Default`
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );
216
217    /// This is used for imatrix generation internally. Begin stats tracking.
218    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
219        let layers = self
220            .get_layers()
221            .0
222            .into_iter()
223            .map(|(layer, _)| layer)
224            .collect::<Vec<_>>();
225        for layer in layers {
226            Arc::get_mut(layer).unwrap().begin_track_stats()?;
227        }
228        Ok(())
229    }
230
231    /// End stats tracking and return the imatrix data
232    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
233        let layers = self
234            .get_layers()
235            .0
236            .into_iter()
237            .enumerate()
238            .map(|(i, (layer, _))| (i, layer))
239            .collect::<Vec<_>>();
240        let mut data = HashMap::new();
241        for (i, layer) in layers {
242            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
243        }
244        Ok(CollectedImatrixData(data))
245    }
246
    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
    /// <https://arxiv.org/abs/2310.02410>
    ///
    /// Same return shape as `get_layers`. The default implementation simply forwards
    /// to `get_layers`, so implementors without a dedicated expert-only layer set
    /// get the full layer list.
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }
258
259    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
260    /// This is used for imatrix generation internally. Begin stats tracking.
261    fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
262        let layers = self
263            .get_layers()
264            .0
265            .into_iter()
266            .map(|(layer, _)| layer)
267            .collect::<Vec<_>>();
268        for layer in layers {
269            Arc::get_mut(layer).unwrap().begin_track_stats()?;
270        }
271        Ok(())
272    }
273
274    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
275    /// End stats tracking and return the imatrix data
276    fn extract_imatrix_data_moe_experts_only(
277        &mut self,
278    ) -> candle_core::Result<CollectedImatrixData> {
279        let layers = self
280            .get_layers()
281            .0
282            .into_iter()
283            .enumerate()
284            .map(|(i, (layer, _))| (i, layer))
285            .collect::<Vec<_>>();
286        let mut data = HashMap::new();
287        for (i, layer) in layers {
288            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
289        }
290        Ok(CollectedImatrixData(data))
291    }
292
    /// Imatrix entry names, in the specific order the model produces ISQ layers (`None`
    /// means do not search for that layer in the imatrix file). This is used to pair ISQ
    /// layers with the corresponding imatrix weights.
    ///
    /// - This is only for loading from a llama.cpp imatrix file.
    /// - Corresponds to `IsqOrganization::Default`
    ///
    /// The default implementation always returns an error stating that the model does
    /// not support imatrix quantization.
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // TODO: make this required.
        candle_core::bail!("This model does not support quantizing with an imatrix.");
    }
303
    /// Residual tensors for generating a UQFF file, as `(name, tensor)` pairs.
    /// Counterpart to [`Self::get_layers`].
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;
306
    /// Residual tensors for generating a UQFF file. Counterpart to
    /// [`Self::get_layers_moe_experts_only`].
    ///
    /// The default returns `None`, in which case `quantize` falls back to
    /// [`Self::residual_tensors`].
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }
311
312    /// Quantize the model in-situ.
313    ///
314    /// This function will also create a UQFF file, or, if the model supports it (residual tensors are returned),
315    /// a full serialization is created.
316    #[allow(clippy::too_many_arguments)]
317    fn quantize(
318        &mut self,
319        dtype: Option<IsqType>,
320        device: Device,
321        topology: Option<&Topology>,
322        silent: bool,
323        imatrix_source: Option<ImatrixDataSource<'_>>,
324        organization: IsqOrganization,
325        write_artifacts: Option<&PathBuf>,
326        full_ser: UqffFullSer<'_>,
327        multi_progress: Arc<MultiProgress>,
328    ) -> candle_core::Result<()> {
329        {
330            let imatrix_to_weight = match imatrix_source {
331                Some(ImatrixDataSource::File(imatrix)) => {
332                    let ext = imatrix.extension().ok_or(candle_core::Error::msg(
333                        "Expected an extension for the imatrix source file.",
334                    ))?;
335                    if ext == "cimatrix" {
336                        info!(
337                            "Loading collected imatrix source file: `{}`",
338                            imatrix.display()
339                        );
340                        Some(CollectedImatrixData::load_imatrix(imatrix)?.0)
341                    } else if ext == "imatrix" {
342                        info!(
343                            "Loading GGUF-format imatrix source file: `{}`",
344                            imatrix.display()
345                        );
346                        let mut imatrix_data =
347                            quantized::imatrix_file::load_imatrix(imatrix.clone())?;
348                        let imatrix_mapping = self
349                            .imatrix_names()?
350                            .into_iter()
351                            .enumerate()
352                            .collect::<HashMap<_, _>>();
353
354                        let layer_to_weight = imatrix_mapping
355                            .into_iter()
356                            .map(|(i, name)| {
357                                if let Some(name) = name {
358                                    (i, Some(imatrix_data.remove(&name).unwrap()))
359                                } else {
360                                    (i, None)
361                                }
362                            })
363                            .collect::<HashMap<_, _>>();
364                        info!(
365                            "Quantizing with imatrix file `{}`, {} imatrix weights",
366                            imatrix.display(),
367                            layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
368                        );
369                        Some(layer_to_weight)
370                    } else {
371                        warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
372                        info!(
373                            "Loading GGUF-format imatrix source file: `{}`",
374                            imatrix.display()
375                        );
376
377                        let mut imatrix_data =
378                            quantized::imatrix_file::load_imatrix(imatrix.clone())?;
379                        let imatrix_mapping = self
380                            .imatrix_names()?
381                            .into_iter()
382                            .enumerate()
383                            .collect::<HashMap<_, _>>();
384
385                        let layer_to_weight = imatrix_mapping
386                            .into_iter()
387                            .map(|(i, name)| {
388                                if let Some(name) = name {
389                                    (i, Some(imatrix_data.remove(&name).unwrap()))
390                                } else {
391                                    (i, None)
392                                }
393                            })
394                            .collect::<HashMap<_, _>>();
395                        info!(
396                            "Quantizing with imatrix file `{}`, {} imatrix weights",
397                            imatrix.display(),
398                            layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
399                        );
400                        Some(layer_to_weight)
401                    }
402                }
403                Some(ImatrixDataSource::Collected) => {
404                    let data = match organization {
405                        IsqOrganization::Default => self.extract_imatrix_data()?,
406                        IsqOrganization::MoeExpertsOnly => {
407                            self.extract_imatrix_data_moe_experts_only()?
408                        }
409                    };
410                    // Save the collected imatrix data so users can reuse it
411                    let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
412                    let save_path = format!("collected-{count}.cimatrix");
413                    info!("Saving collected imatrix data to `{save_path}`");
414                    data.save_imatrix(save_path)?;
415                    info!("Quantizing with collected imatrix data, {count} imatrix weights");
416                    Some(data.0)
417                }
418                None => {
419                    // Dummy, just for zip
420                    None
421                }
422            };
423
424            let (mut tensors, mapper) = match organization {
425                IsqOrganization::Default => self.get_layers(),
426                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
427            };
428
429            let imatrix_to_weight: Vec<Option<Vec<f32>>> =
430                if let Some(mut imatrix_to_weight) = imatrix_to_weight {
431                    let ordered_keys = imatrix_to_weight
432                        .keys()
433                        .copied()
434                        .sorted()
435                        .collect::<Vec<_>>();
436                    ordered_keys
437                        .into_iter()
438                        .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
439                        .collect()
440                } else {
441                    vec![None; tensors.len()]
442                };
443
444            let total_tensors = tensors.len();
445            let n_quantized = AtomicUsize::new(0);
446            if let Some(topology) = topology {
447                let mut dtypes = HashSet::new();
448                for layer in topology.0.iter().flatten() {
449                    if let LayerTopology {
450                        isq: Some(isq_dtype),
451                        device: _,
452                    } = layer
453                    {
454                        dtypes.insert(isq_dtype);
455                    }
456                }
457                info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
458            } else {
459                info!("Applying in-situ quantization into {dtype:?} to {total_tensors} tensors.");
460            }
461            let bar = ProgressBar::new(total_tensors as u64);
462            bar.set_style(
463                ProgressStyle::default_bar()
464                    .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
465                    .unwrap()
466                    .progress_chars("#>-"),
467            );
468            multi_progress.add(bar.clone());
469
470            let layers = topology.map(|x| {
471                x.0.iter()
472                    .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
473                    .collect::<Vec<_>>()
474            });
475
476            let mut devices_and_dtypes = Vec::new();
477            for (_, layer_num) in &tensors {
478                let device = if let Some(ref layers) = layers {
479                    if let Some(layer) = layer_num {
480                        layers
481                            .get(*layer)
482                            .as_ref()
483                            .map(|x| x.1.clone())
484                            .unwrap_or(Some(device.clone()))
485                            .unwrap_or(device.clone())
486                    } else {
487                        device.clone()
488                    }
489                } else if let Some(layer_num) = layer_num {
490                    mapper
491                        .device_for(*layer_num, false)
492                        .cloned()
493                        .unwrap_or(device.clone())
494                } else {
495                    device.clone()
496                };
497                let dtype = if let Some(ref layers) = layers {
498                    if let Some(layer) = layer_num {
499                        layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
500                    } else {
501                        dtype
502                    }
503                } else {
504                    dtype
505                };
506                devices_and_dtypes.push((device, dtype));
507            }
508
509            let t_start = Instant::now();
510
511            use rayon::iter::IntoParallelRefIterator;
512
513            // Get the MINIMUM of the max isq threads the quant method
514            let mut minimum_max_threads = {
515                let current_rayon_threads = rayon::current_num_threads();
516                if let Some(dtype) = dtype {
517                    dtype
518                        .get_max_isq_cpu_threads()
519                        .map(usize::from)
520                        .unwrap_or(current_rayon_threads)
521                } else {
522                    current_rayon_threads
523                }
524            };
525            if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
526                minimum_max_threads = 1;
527            }
528
529            if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
530                // Collected imatrix means that the model is potentially on the gpu already
531                minimum_max_threads = 1;
532            }
533
534            info!("Applying ISQ on {minimum_max_threads} threads.");
535
536            let pool = rayon::ThreadPoolBuilder::new()
537                .num_threads(minimum_max_threads)
538                .build()
539                .map_err(candle_core::Error::msg)?;
540
541            let guard = QuantizeOntoGuard::new();
542
543            pool.install(|| {
544                use indicatif::ParallelProgressIterator;
545                use rayon::iter::{
546                    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
547                };
548                if silent {
549                    tensors
550                        .par_iter_mut()
551                        .zip(devices_and_dtypes)
552                        .zip(imatrix_to_weight)
553                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
554                            **tensor = tensor
555                                .clone()
556                                .apply_isq(
557                                    dtype,
558                                    device.clone(),
559                                    &n_quantized,
560                                    imatrix_weight,
561                                    guard.clone(),
562                                )
563                                .unwrap();
564                            device.synchronize().unwrap();
565                        });
566                } else {
567                    tensors
568                        .par_iter_mut()
569                        .zip(devices_and_dtypes)
570                        .zip(imatrix_to_weight)
571                        .progress_with(bar)
572                        .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
573                            **tensor = tensor
574                                .clone()
575                                .apply_isq(
576                                    dtype,
577                                    device.clone(),
578                                    &n_quantized,
579                                    imatrix_weight,
580                                    guard.clone(),
581                                )
582                                .unwrap();
583                            device.synchronize().unwrap();
584                        });
585                }
586            });
587
588            if let Some(serialized) = write_artifacts {
589                info!(
590                    "Serializing {total_tensors} ISQ tensors to `{}`.",
591                    serialized.display()
592                );
593
594                if serialized.extension().is_none_or(|ext| ext != "uqff") {
595                    candle_core::bail!("UQFF output path extension must be `.uqff`",);
596                }
597
598                let bar = ProgressBar::new(total_tensors as u64);
599                bar.set_style(
600                    ProgressStyle::default_bar()
601                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
602                        .unwrap()
603                        .progress_chars("#>-"),
604                );
605
606                #[cfg(not(feature = "metal"))]
607                let n_threads = 2;
608                #[cfg(feature = "metal")]
609                let n_threads = 1;
610
611                let pool = rayon::ThreadPoolBuilder::new()
612                    .num_threads(n_threads)
613                    .build()
614                    .map_err(candle_core::Error::msg)?;
615
616                let quantized_values = pool.install(|| {
617                    if silent {
618                        tensors
619                            .par_iter()
620                            .enumerate()
621                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
622                            .map(|(i, (layer, _))| {
623                                Ok((
624                                    i.to_string(),
625                                    match layer.serialize()? {
626                                        Cow::Borrowed(_) => unreachable!(),
627                                        Cow::Owned(owned) => owned,
628                                    },
629                                ))
630                            })
631                            .collect::<candle_core::Result<Vec<_>>>()
632                    } else {
633                        tensors
634                            .par_iter()
635                            .enumerate()
636                            .progress_with(bar)
637                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
638                            .map(|(i, (layer, _))| {
639                                Ok((
640                                    i.to_string(),
641                                    match layer.serialize()? {
642                                        Cow::Borrowed(_) => unreachable!(),
643                                        Cow::Owned(owned) => owned,
644                                    },
645                                ))
646                            })
647                            .collect::<candle_core::Result<Vec<_>>>()
648                    }
649                });
650                let quantized_values = quantized_values?;
651
652                let parent = serialized
653                    .parent()
654                    .context("Target UQFF path must have a filename!")?;
655
656                std::fs::create_dir_all(parent)?;
657
658                let file_stem = serialized
659                    .file_stem()
660                    .context("Target UQFF path must have a file stem!")?
661                    .to_string_lossy()
662                    .to_string();
663
664                // Shard quantized values by cumulative byte size, max MAX_UQFF_SIZE_BYTES per file
665                let mut current_chunk = Vec::new();
666                let mut current_bytes: usize = 0;
667                let mut shard_index = 0;
668
669                // Every 10GB, flush the file. Then save any remaining tensors
670                for (name, tensor) in quantized_values.iter() {
671                    let tensor_bytes = tensor.len();
672                    if !current_chunk.is_empty()
673                        && current_bytes + tensor_bytes > MAX_UQFF_SIZE_BYTES
674                    {
675                        let mut shard_path = parent.to_path_buf();
676                        shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
677                        info!(
678                            "Writing shard {} to `{}`",
679                            shard_index,
680                            shard_path.display()
681                        );
682                        safetensors::serialize_to_file(current_chunk.clone(), &None, &shard_path)?;
683                        shard_index += 1;
684                        current_chunk.clear();
685                        current_bytes = 0;
686                    }
687                    current_bytes += tensor_bytes;
688                    current_chunk.push((name, CowBytesView::new(Cow::Borrowed(tensor))));
689                }
690
691                if !current_chunk.is_empty() {
692                    let mut shard_path = parent.to_path_buf();
693                    shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
694                    info!(
695                        "Writing final shard {} to `{}`",
696                        shard_index,
697                        shard_path.display()
698                    );
699                    safetensors::serialize_to_file(current_chunk.clone(), &None, &shard_path)?;
700                }
701
702                let residual = match organization {
703                    IsqOrganization::Default => self.residual_tensors(),
704                    IsqOrganization::MoeExpertsOnly => self
705                        .residual_tensors_moe_experts_only()
706                        .unwrap_or(self.residual_tensors()),
707                };
708
709                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
710                let config_out = parent.join("config.json");
711                let tokenizer_out = parent.join("tokenizer.json");
712                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
713                let gen_cfg_out = parent.join("generation_config.json");
714                let processor_out = parent.join("processor_config.json");
715                let preprocessor_out = parent.join("preprocessor_config.json");
716
717                info!(
718                    "Serializing {} residual tensors to `{}`.",
719                    residual.len(),
720                    residual_out.display()
721                );
722
723                safetensors::serialize_to_file(residual, &None, &residual_out)?;
724
725                let UqffFullSer {
726                    tokenizer,
727                    template_filename,
728                    generation_config,
729                    config,
730                    processor_filename,
731                    preprocessor_filename,
732                } = full_ser;
733
734                info!("Serializing configuration to `{}`.", config_out.display());
735
736                std::fs::write(config_out, config)?;
737
738                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());
739
740                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
741                    .map_err(candle_core::Error::msg)?;
742
743                if let Some(template_filename) = template_filename {
744                    info!(
745                        "Serializing tokenizer config to `{}`.",
746                        tokenizer_cfg_out.display()
747                    );
748
749                    let template =
750                        std::fs::read(template_filename).map_err(candle_core::Error::msg)?;
751                    std::fs::write(&tokenizer_cfg_out, template)
752                        .map_err(candle_core::Error::msg)?;
753                }
754
755                if let Some(generation_config) = generation_config {
756                    info!(
757                        "Serializing generation config to `{}`.",
758                        gen_cfg_out.display()
759                    );
760
761                    let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
762                    std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
763                }
764
765                if let Some(processor_config) = processor_filename {
766                    info!(
767                        "Serializing processor config to `{}`.",
768                        processor_out.display()
769                    );
770
771                    let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
772                    std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
773                }
774
775                if let Some(preprocessor_config) = preprocessor_filename {
776                    info!(
777                        "Serializing preprocessor config to `{}`.",
778                        preprocessor_out.display()
779                    );
780
781                    let cfg =
782                        std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
783                    std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
784                }
785            }
786            let delta = Instant::now().duration_since(t_start).as_secs_f32();
787            info!("Applied in-situ quantization into {dtype:?} to {n_quantized:?} tensors out of {total_tensors} total tensors. Took {delta:.2}s", );
788        }
789        Ok(())
790    }
791
792    fn load_from_artifacts(
793        &mut self,
794        device: Device,
795        topology: Option<&Topology>,
796        silent: bool,
797        artifacts: &[PathBuf],
798    ) -> candle_core::Result<()> {
799        let (tensors, mapper) = self.get_layers();
800        let total_tensors = tensors.len();
801
802        let layers = topology.map(|x| {
803            x.0.iter()
804                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
805                .collect::<Vec<_>>()
806        });
807
808        let mut devices = Vec::new();
809        let mut comms = Vec::new();
810        for (_, layer_num) in &tensors {
811            let device = if let Some(ref layers) = layers {
812                if let Some(layer) = layer_num {
813                    layers
814                        .get(*layer)
815                        .as_ref()
816                        .map(|x| x.1.clone())
817                        .unwrap_or(Some(device.clone()))
818                        .unwrap_or(device.clone())
819                } else {
820                    device.clone()
821                }
822            } else if let Some(layer_num) = layer_num {
823                mapper
824                    .device_for(*layer_num, false)
825                    .cloned()
826                    .unwrap_or(device.clone())
827            } else {
828                device.clone()
829            };
830            devices.push(device);
831            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
832        }
833
834        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::multi(artifacts)? };
835
836        let artifact_isqs = artifacts
837            .tensors()
838            .into_iter()
839            .map(|(name, tensor)| {
840                (
841                    name.parse::<usize>()
842                        .expect("Name should be parseable as usize"),
843                    tensor,
844                )
845            })
846            .collect::<HashMap<_, _>>();
847
848        if artifact_isqs.len() != total_tensors {
849            candle_core::bail!(
850                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
851                artifact_isqs.len(),
852            );
853        }
854
855        let bar = ProgressBar::new(total_tensors as u64);
856        bar.set_style(
857            ProgressStyle::default_bar()
858                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
859                .unwrap()
860                .progress_chars("#>-"),
861        );
862
863        let t_start = Instant::now();
864
865        let guard = QuantizeOntoGuard::new();
866
867        if silent {
868            (0..tensors.len())
869                .into_par_iter()
870                .zip(tensors)
871                .map(|(i, (tensor, _))| {
872                    if let Some(artifact) = artifact_isqs.get(&i) {
873                        let artifact = artifact.data();
874
875                        let comm = comms[i].clone();
876                        let deserialized = match tensor.is_distributed() {
877                            Some(DistributedKind::ColumnParallel) => {
878                                ColumnParallelLayer::deserialize(
879                                    Cow::from(artifact),
880                                    &devices[i],
881                                    &comm,
882                                    guard.clone(),
883                                )?
884                            }
885                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
886                                Cow::from(artifact),
887                                &devices[i],
888                                &comm,
889                                guard.clone(),
890                            )?,
891                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
892                                Cow::from(artifact),
893                                &devices[i],
894                                &comm,
895                                guard.clone(),
896                            )?,
897                            None => {
898                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
899                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
900                                match QuantizedSerdeType::try_from(isq_type as usize)? {
901                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
902                                        Cow::from(artifact),
903                                        &devices[i],
904                                        &comm,
905                                        guard.clone(),
906                                    )?,
907                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
908                                        Cow::from(artifact),
909                                        &devices[i],
910                                        &comm,
911                                        guard.clone(),
912                                    )?,
913                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
914                                        Cow::from(artifact),
915                                        &devices[i],
916                                        &comm,
917                                        guard.clone(),
918                                    )?,
919                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
920                                        Cow::from(artifact),
921                                        &devices[i],
922                                        &comm,
923                                        guard.clone(),
924                                    )?,
925                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
926                                        Cow::from(artifact),
927                                        &devices[i],
928                                        &comm,
929                                        guard.clone(),
930                                    )?,
931                                }
932                            }
933                        };
934                        *tensor = deserialized;
935                    }
936                    Ok(())
937                })
938                .collect::<candle_core::Result<Vec<_>>>()?;
939        } else {
940            (0..tensors.len())
941                .into_par_iter()
942                .zip(tensors)
943                .progress_with(bar)
944                .map(|(i, (tensor, _))| {
945                    if let Some(artifact) = artifact_isqs.get(&i) {
946                        let artifact = artifact.data();
947
948                        let comm = comms[i].clone();
949                        let deserialized = match tensor.is_distributed() {
950                            Some(DistributedKind::ColumnParallel) => {
951                                ColumnParallelLayer::deserialize(
952                                    Cow::from(artifact),
953                                    &devices[i],
954                                    &comm,
955                                    guard.clone(),
956                                )?
957                            }
958                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
959                                Cow::from(artifact),
960                                &devices[i],
961                                &comm,
962                                guard.clone(),
963                            )?,
964                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
965                                Cow::from(artifact),
966                                &devices[i],
967                                &comm,
968                                guard.clone(),
969                            )?,
970                            None => {
971                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
972                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
973                                match QuantizedSerdeType::try_from(isq_type as usize)? {
974                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
975                                        Cow::from(artifact),
976                                        &devices[i],
977                                        &comm,
978                                        guard.clone(),
979                                    )?,
980                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
981                                        Cow::from(artifact),
982                                        &devices[i],
983                                        &comm,
984                                        guard.clone(),
985                                    )?,
986                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
987                                        Cow::from(artifact),
988                                        &devices[i],
989                                        &comm,
990                                        guard.clone(),
991                                    )?,
992                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
993                                        Cow::from(artifact),
994                                        &devices[i],
995                                        &comm,
996                                        guard.clone(),
997                                    )?,
998                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
999                                        Cow::from(artifact),
1000                                        &devices[i],
1001                                        &comm,
1002                                        guard.clone(),
1003                                    )?,
1004                                }
1005                            }
1006                        };
1007                        *tensor = deserialized;
1008                    }
1009                    Ok(())
1010                })
1011                .collect::<candle_core::Result<Vec<_>>>()?;
1012        }
1013
1014        let delta = Instant::now().duration_since(t_start).as_secs_f32();
1015        info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s", );
1016
1017        Ok(())
1018    }
1019}
1020
1021/// Trait for loading models with ISQ.
1022pub(crate) trait IsqModelLoader {
1023    /// Regex to match layers which will have standard *immediate* ISQ applied.
1024    ///
1025    /// Only called on non-adapter models!
1026    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1027        Ok(Vec::new())
1028    }
1029
1030    /// Regex to match layers which will have standard MoQE *immediate* ISQ applied.
1031    ///
1032    /// Only called on non-adapter models!
1033    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
1034        self.isq_layer_regexes(config)
1035    }
1036
1037    /// Regex to match layers which will have standard ISQ applied.
1038    ///
1039    /// Only called on non-adapter models!
1040    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1041        Ok(Vec::new())
1042    }
1043
1044    /// Regex to match layers which will have standard MoQE ISQ applied.
1045    ///
1046    /// Only called on non-adapter models!
1047    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
1048        self.isq_layer_regexes(config)
1049    }
1050}