mistralrs_core/pipeline/isq.rs

use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    env,
    fs::File,
    path::PathBuf,
    str::FromStr,
    sync::{atomic::AtomicUsize, Arc},
    time::Instant,
};

/// Wrapper around a `Cow<'a, [u8]>` buffer that implements
/// `safetensors::tensor::View`.
///
/// *Purpose*: lets us pass raw byte buffers to
/// `safetensors::serialize_to_file` without cloning them into a `Vec<u8>` or
/// converting to a higher-level tensor type.
/// We expose the buffer as a 1-D `u8` tensor of shape `[len]`.
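///
/// # Example
///
/// A minimal usage sketch (buffer and file names are illustrative), serializing
/// a raw byte buffer without copying it:
///
/// ```ignore
/// use std::borrow::Cow;
///
/// let blob: Vec<u8> = vec![0u8; 16];
/// let view = CowBytesView::new(Cow::Borrowed(blob.as_slice()));
/// safetensors::serialize_to_file(vec![("tensor0", view)], None, std::path::Path::new("out.safetensors"))?;
/// ```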
#[derive(Clone)]
pub struct CowBytesView<'a> {
    data: Cow<'a, [u8]>,
    shape: [usize; 1],
}

impl<'a> CowBytesView<'a> {
    /// Convenience constructor.
    pub fn new(data: Cow<'a, [u8]>) -> Self {
        let len = data.len();
        Self { data, shape: [len] }
    }
}

impl safetensors::tensor::View for CowBytesView<'_> {
    fn dtype(&self) -> safetensors::tensor::Dtype {
        // Serialize as raw bytes.
        safetensors::tensor::Dtype::U8
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn data(&self) -> Cow<'_, [u8]> {
        assert!(matches!(self.data, Cow::Borrowed(_)));
        // Cloning a borrowed `Cow` is cheap: it copies only the enum, not the data.
        self.data.clone()
    }

    fn data_len(&self) -> usize {
        self.data.len()
    }
}

use anyhow::Result;
use candle_core::{quantized, Context, Device, Tensor};
use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
use itertools::Itertools;
use mistralrs_quant::{
    AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
    HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    ReplicatedLayer, RowParallelLayer, UnquantLinear,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use regex::Regex;
use serde::Deserialize;
use tokenizers::Tokenizer;
use tracing::{info, warn};

use crate::{
    device_map::DeviceMapper, pipeline::EmbeddingModulePaths, topology::LayerTopology,
    utils::progress::configure_progress_bar, Topology,
};

pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
// Maximum of 10 GiB per UQFF shard file.
const MAX_UQFF_SIZE_BYTES: usize = 10 * 1024 * 1024 * 1024;
pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";

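// A sketch of how a multi-file UQFF specification is split into individual
// paths (filenames are illustrative):
//
//     "model-0.uqff;model-1.uqff".split(UQFF_MULTI_FILE_DELIMITER)
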
/// Parse ISQ value.
///
/// If the provided value is one of the integers 2, 3, 4, 5, 6, or 8, the best quantization type
/// for the device is chosen automatically. The fallback is always a Q/K quantization, but on
/// Metal the values 2, 3, 4, 6, and 8 select the fast AFQ types instead.
///
/// One of:
/// - `Q4_0`
/// - `Q4_1`
/// - `Q5_0`
/// - `Q5_1`
/// - `Q8_0`
/// - `Q8_1`
/// - `Q2K`
/// - `Q3K`
/// - `Q4K`
/// - `Q5K`
/// - `Q6K`
/// - `Q8K`
/// - `HQQ4`
/// - `HQQ8`
/// - `FP8`
/// - `AFQ2`
/// - `AFQ3`
/// - `AFQ4`
/// - `AFQ6`
/// - `AFQ8`
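///
/// # Example
///
/// A minimal sketch; `Device::Cpu` stands in for whatever device the caller has:
///
/// ```ignore
/// let tp = parse_isq_value("4", Some(&candle_core::Device::Cpu))?;
/// assert!(matches!(tp, IsqType::Q4K)); // on Metal, "4" selects IsqType::AFQ4
/// ```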
pub fn parse_isq_value(s: &str, device: Option<&Device>) -> Result<IsqType, String> {
    let is_metal = device.map(|device| device.is_metal()).unwrap_or(false);
    let tp = match s.to_lowercase().as_str() {
        "2" if is_metal => IsqType::AFQ2,
        "2" if !is_metal => IsqType::Q2K,
        "3" if is_metal => IsqType::AFQ3,
        "3" if !is_metal => IsqType::Q3K,
        "4" if is_metal => IsqType::AFQ4,
        "4" if !is_metal => IsqType::Q4K,
        "5" => IsqType::Q5K,
        "6" if is_metal => IsqType::AFQ6,
        "6" if !is_metal => IsqType::Q6K,
        "8" if is_metal => IsqType::AFQ8,
        "8" if !is_metal => IsqType::Q8_0,
        "q4_0" => IsqType::Q4_0,
        "q4_1" => IsqType::Q4_1,
        "q5_0" => IsqType::Q5_0,
        "q5_1" => IsqType::Q5_1,
        "q8_0" => IsqType::Q8_0,
        "q8_1" => IsqType::Q8_1,
        "q2k" => IsqType::Q2K,
        "q3k" => IsqType::Q3K,
        "q4k" => IsqType::Q4K,
        "q5k" => IsqType::Q5K,
        "q6k" => IsqType::Q6K,
        "q8k" => IsqType::Q8K,
        "hqq8" => IsqType::HQQ8,
        "hqq4" => IsqType::HQQ4,
        "fp8" => IsqType::F8E4M3,
        "afq8" => IsqType::AFQ8,
        "afq6" => IsqType::AFQ6,
        "afq4" => IsqType::AFQ4,
        "afq3" => IsqType::AFQ3,
        "afq2" => IsqType::AFQ2,
        // "hqq3" => IsqType::HQQ3,
        // "hqq2" => IsqType::HQQ2,
        // "hqq1" => IsqType::HQQ1,
        _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `5`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
    };
    #[cfg(feature = "cuda")]
    {
        if !matches!(
            tp,
            IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q8_0
                | IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q5K
                | IsqType::Q6K
                | IsqType::HQQ8
                | IsqType::HQQ4
                | IsqType::F8E4M3 // | IsqType::HQQ3
                                  // | IsqType::HQQ2
                                  // | IsqType::HQQ1
        ) {
            return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
        }
    }
    Ok(tp)
}

#[derive(Clone, Debug, Copy, Default, Deserialize)]
pub enum IsqOrganization {
    #[default]
    #[serde(rename = "default")]
    Default,
    /// Only quantize MoE experts, if applicable. This enables MoQE.
    /// <https://arxiv.org/abs/2310.02410>
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}

impl FromStr for IsqOrganization {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "default" => Ok(Self::Default),
            "moqe" => Ok(Self::MoeExpertsOnly),
            other => Err(format!(
                "Expected ISQ organization `default` or `moqe`, got `{other}`"
            )),
        }
    }
}

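// A sketch of parsing an organization from a CLI-style string (the value is
// illustrative):
//
//     let org: IsqOrganization = "moqe".parse().unwrap_or_default();
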
pub struct UqffFullSer<'a> {
    pub tokenizer: &'a Tokenizer,
    pub template_filename: &'a Option<PathBuf>,
    pub modules: Option<&'a String>,
    pub module_paths: Option<&'a [EmbeddingModulePaths]>,
    pub generation_config: Option<&'a PathBuf>,
    pub config: String,
    pub processor_filename: &'a Option<PathBuf>,
    pub preprocessor_filename: &'a Option<PathBuf>,
}

#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    File(&'a PathBuf),
    Collected,
}

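// A sketch of choosing an imatrix source (the path is illustrative):
//
//     let path = PathBuf::from("calibration.imatrix");
//     let source = ImatrixDataSource::File(&path);
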
pub trait IsqModel {
    /// Corresponds to `IsqOrganization::Default`
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );

    /// This is used for imatrix generation internally. Begin stats tracking.
    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

    /// End stats tracking and return the imatrix data.
    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

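    // A sketch of the collected-imatrix lifecycle from the caller's side
    // (illustrative; the calibration forward passes are model-specific):
    //
    //     model.begin_track_stats()?;                  // start accumulating stats
    //     /* run calibration forward passes */
    //     let imatrix = model.extract_imatrix_data()?; // per-layer f32 data
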
    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
    /// <https://arxiv.org/abs/2310.02410>
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }

    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
    /// This is used for imatrix generation internally. Begin stats tracking.
    fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers_moe_experts_only()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
    /// End stats tracking and return the imatrix data.
    fn extract_imatrix_data_moe_experts_only(
        &mut self,
    ) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers_moe_experts_only()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

    /// The names of the ISQ layers, in the specific order in which the model produces them
    /// (`None` means the layer is not searched for in the imatrix file). This is used to pair
    /// ISQ layers with the corresponding imatrix weights.
    ///
    /// - This is only for loading from a llama.cpp imatrix file.
    /// - Corresponds to `IsqOrganization::Default`
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // TODO: make this required.
        candle_core::bail!("This model does not support quantizing with an imatrix.");
    }

    /// Residual tensors for generating a UQFF file. Counterpart to [`get_layers`].
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;

    /// Residual tensors for generating a UQFF file. Counterpart to [`get_layers_moe_experts_only`].
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }

    /// Quantize the model in-situ.
    ///
    /// When `write_artifacts` is set, this function also creates a UQFF file and, if the model
    /// supports it (i.e. residual tensors are returned), a full serialization.
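    ///
    /// # Example
    ///
    /// A minimal sketch (argument values are illustrative):
    ///
    /// ```ignore
    /// model.quantize(
    ///     Some(IsqType::Q4K),
    ///     Device::Cpu,
    ///     None,                           // topology
    ///     false,                          // silent
    ///     None,                           // imatrix_source
    ///     IsqOrganization::Default,
    ///     true,                           // apply_quantization
    ///     None,                           // write_artifacts: no UQFF output
    ///     full_ser,                       // a prepared `UqffFullSer`
    ///     Arc::new(MultiProgress::new()),
    /// )?;
    /// ```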
    #[allow(clippy::too_many_arguments)]
    fn quantize(
        &mut self,
        dtype: Option<IsqType>,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        imatrix_source: Option<ImatrixDataSource<'_>>,
        organization: IsqOrganization,
        apply_quantization: bool,
        write_artifacts: Option<&PathBuf>,
        full_ser: UqffFullSer<'_>,
        multi_progress: Arc<MultiProgress>,
    ) -> candle_core::Result<()> {
        {
            let mut imatrix_source = imatrix_source;
            let mut imatrix_to_weight_map: Option<HashMap<usize, Option<Vec<f32>>>> =
                if apply_quantization {
                    match imatrix_source.take() {
                        Some(ImatrixDataSource::File(imatrix)) => {
                            let ext = imatrix.extension().ok_or(candle_core::Error::msg(
                                "Expected an extension for the imatrix source file.",
                            ))?;
                            if ext == "cimatrix" {
                                info!(
                                    "Loading collected imatrix source file: `{}`",
                                    imatrix.display()
                                );
                                let data = CollectedImatrixData::load_imatrix(imatrix)?;
                                info!(
                                    "Quantizing with collected imatrix data, {} imatrix weights",
                                    data.0.iter().filter(|(_, x)| x.is_some()).count()
                                );
                                Some(data.0)
                            } else {
                                if ext != "imatrix" {
                                    warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
                                }
                                info!(
                                    "Loading GGUF-format imatrix source file: `{}`",
                                    imatrix.display()
                                );
                                let mut imatrix_data =
                                    quantized::imatrix_file::load_imatrix(imatrix.clone())?;
                                let imatrix_mapping = self
                                    .imatrix_names()?
                                    .into_iter()
                                    .enumerate()
                                    .collect::<HashMap<_, _>>();

                                let layer_to_weight = imatrix_mapping
                                    .into_iter()
                                    .map(|(i, name)| {
                                        if let Some(name) = name {
                                            (i, Some(imatrix_data.remove(&name).unwrap()))
                                        } else {
                                            (i, None)
                                        }
                                    })
                                    .collect::<HashMap<_, _>>();
                                info!(
                                    "Quantizing with imatrix file `{}`, {} imatrix weights",
                                    imatrix.display(),
                                    layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
                                );
                                Some(layer_to_weight)
                            }
                        }
                        Some(ImatrixDataSource::Collected) => {
                            let data = match organization {
                                IsqOrganization::Default => self.extract_imatrix_data()?,
                                IsqOrganization::MoeExpertsOnly => {
                                    self.extract_imatrix_data_moe_experts_only()?
                                }
                            };
                            // Save the collected imatrix data so users can reuse it.
                            let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
                            let save_path = format!("collected-{count}.cimatrix");
                            info!("Saving collected imatrix data to `{save_path}`");
                            data.save_imatrix(save_path)?;
                            info!(
                                "Quantizing with collected imatrix data, {count} imatrix weights"
                            );
                            Some(data.0)
                        }
                        None => None,
                    }
                } else {
                    if imatrix_source.is_some() {
                        info!("Imatrix source provided but quantization disabled; ignoring input.");
                    }
                    None
                };

            let (mut tensors, mapper) = match organization {
                IsqOrganization::Default => self.get_layers(),
                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
            };

            let total_tensors = tensors.len();

            if apply_quantization {
                let imatrix_to_weight: Vec<Option<Vec<f32>>> =
                    if let Some(mut imatrix_to_weight) = imatrix_to_weight_map.take() {
                        let ordered_keys = imatrix_to_weight
                            .keys()
                            .copied()
                            .sorted()
                            .collect::<Vec<_>>();
                        ordered_keys
                            .into_iter()
                            .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
                            .collect()
                    } else {
                        vec![None; tensors.len()]
                    };

                let n_quantized = AtomicUsize::new(0);
                if let Some(topology) = topology {
                    let mut dtypes = HashSet::new();
                    for layer in topology.layers.iter().flatten() {
                        if let LayerTopology {
                            isq: Some(isq_dtype),
                            device: _,
                        } = layer
                        {
                            dtypes.insert(isq_dtype);
                        }
                    }
                    info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
                } else {
                    info!(
                        "Applying in-situ quantization into {dtype:?} to {total_tensors} tensors."
                    );
                }
                let bar = ProgressBar::new(total_tensors as u64);
                configure_progress_bar(&bar);
                bar.set_style(
                    ProgressStyle::default_bar()
                        .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
                        .unwrap()
                        .progress_chars("#>-"),
                );
                multi_progress.add(bar.clone());

                let layers = topology.map(|x| {
                    x.layers
                        .iter()
                        .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                        .collect::<Vec<_>>()
                });

                let mut devices_and_dtypes = Vec::new();
                for (_, layer_num) in &tensors {
                    let device = if let Some(ref layers) = layers {
                        if let Some(layer) = layer_num {
                            layers
                                .get(*layer)
                                .as_ref()
                                .map(|x| x.1.clone())
                                .unwrap_or(Some(device.clone()))
                                .unwrap_or(device.clone())
                        } else {
                            device.clone()
                        }
                    } else if let Some(layer_num) = layer_num {
                        mapper
                            .device_for(*layer_num, false)
                            .cloned()
                            .unwrap_or(device.clone())
                    } else {
                        device.clone()
                    };
                    let dtype = if let Some(ref layers) = layers {
                        if let Some(layer) = layer_num {
                            layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
                        } else {
                            dtype
                        }
                    } else {
                        dtype
                    };
                    devices_and_dtypes.push((device, dtype));
                }

                let t_start = Instant::now();

                // Use the minimum of the current Rayon thread count and the maximum number
                // of ISQ CPU threads supported by the quantization method.
                let mut minimum_max_threads = {
                    let current_rayon_threads = rayon::current_num_threads();
                    if let Some(dtype) = dtype {
                        dtype
                            .get_max_isq_cpu_threads()
                            .map(usize::from)
                            .unwrap_or(current_rayon_threads)
                    } else {
                        current_rayon_threads
                    }
                };
                if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
                    minimum_max_threads = 1;
                }

                if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
                    // A collected imatrix means the model is potentially on the GPU already.
                    minimum_max_threads = 1;
                }

                info!("Applying ISQ on {minimum_max_threads} threads.");

                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(minimum_max_threads)
                    .build()
                    .map_err(candle_core::Error::msg)?;

                let guard = QuantizeOntoGuard::new();

                pool.install(|| {
                    use indicatif::ParallelProgressIterator;
                    use rayon::iter::{
                        IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
                    };
                    if silent {
                        tensors
                            .par_iter_mut()
                            .zip(devices_and_dtypes)
                            .zip(imatrix_to_weight)
                            .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                                **tensor = tensor
                                    .clone()
                                    .apply_isq(
                                        dtype,
                                        device.clone(),
                                        &n_quantized,
                                        imatrix_weight,
                                        guard.clone(),
                                    )
                                    .unwrap();
                                device.synchronize().unwrap();
                            });
                    } else {
                        tensors
                            .par_iter_mut()
                            .zip(devices_and_dtypes)
                            .zip(imatrix_to_weight)
                            .progress_with(bar)
                            .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                                **tensor = tensor
                                    .clone()
                                    .apply_isq(
                                        dtype,
                                        device.clone(),
                                        &n_quantized,
                                        imatrix_weight,
                                        guard.clone(),
                                    )
                                    .unwrap();
                                device.synchronize().unwrap();
                            });
                    }
                });

                let t_end = Instant::now();
                info!(
                    "Finished quantization pass in {:.2}s ({} tensors).",
                    t_end.duration_since(t_start).as_secs_f32(),
                    total_tensors
                );
            } else if imatrix_source.is_some() {
                info!(
                    "Imatrix data provided but quantization was skipped; existing tensors will be serialized as-is."
                );
            } else if write_artifacts.is_some() {
                info!(
                    "Skipping additional quantization; serializing {total_tensors} existing tensors."
                );
            }

            if let Some(serialized) = write_artifacts {
                info!(
                    "Serializing {total_tensors} ISQ tensors to `{}`.",
                    serialized.display()
                );

                if serialized.extension().is_none_or(|ext| ext != "uqff") {
                    candle_core::bail!("UQFF output path extension must be `.uqff`");
                }

                let bar = ProgressBar::new(total_tensors as u64);
                configure_progress_bar(&bar);
                bar.set_style(
                    ProgressStyle::default_bar()
                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                        .unwrap()
                        .progress_chars("#>-"),
                );

                #[cfg(not(feature = "metal"))]
                let n_threads = 2;
                #[cfg(feature = "metal")]
                let n_threads = 1;

                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(n_threads)
                    .build()
                    .map_err(candle_core::Error::msg)?;

                let quantized_values = pool.install(|| {
                    use rayon::iter::IntoParallelRefIterator;
                    if silent {
                        tensors
                            .par_iter()
                            .enumerate()
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    match layer.serialize()? {
                                        Cow::Borrowed(_) => unreachable!(),
                                        Cow::Owned(owned) => owned,
                                    },
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    } else {
                        tensors
                            .par_iter()
                            .enumerate()
                            .progress_with(bar)
                            .filter(|(_, (layer, _))| layer.isq_serde_supported())
                            .map(|(i, (layer, _))| {
                                Ok((
                                    i.to_string(),
                                    match layer.serialize()? {
                                        Cow::Borrowed(_) => unreachable!(),
                                        Cow::Owned(owned) => owned,
                                    },
                                ))
                            })
                            .collect::<candle_core::Result<Vec<_>>>()
                    }
                });
                let quantized_values = quantized_values?;

                let parent = serialized
                    .parent()
                    .context("Target UQFF path must have a filename!")?;

                std::fs::create_dir_all(parent)?;

                let file_stem = serialized
                    .file_stem()
                    .context("Target UQFF path must have a file stem!")?
                    .to_string_lossy()
                    .to_string();

                // Shard the quantized tensors by cumulative byte size, at most
                // MAX_UQFF_SIZE_BYTES per file.
                let mut current_chunk = Vec::new();
                let mut current_bytes: usize = 0;
                let mut shard_index = 0;

                // Whenever the next tensor would push the current shard over the limit,
                // flush the shard to disk; any remaining tensors are written afterwards.
                for (name, tensor) in quantized_values.iter() {
                    let tensor_bytes = tensor.len();
                    if !current_chunk.is_empty()
                        && current_bytes + tensor_bytes > MAX_UQFF_SIZE_BYTES
                    {
                        let mut shard_path = parent.to_path_buf();
                        shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
                        info!(
                            "Writing shard {} to `{}`",
                            shard_index,
                            shard_path.display()
                        );
                        safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
                        shard_index += 1;
                        current_chunk.clear();
                        current_bytes = 0;
                    }
                    current_bytes += tensor_bytes;
                    current_chunk.push((name, CowBytesView::new(Cow::Borrowed(tensor))));
                }

                if !current_chunk.is_empty() {
                    let mut shard_path = parent.to_path_buf();
                    shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
                    info!(
                        "Writing final shard {} to `{}`",
                        shard_index,
                        shard_path.display()
                    );
                    safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
                }

                let residual = match organization {
                    IsqOrganization::Default => self.residual_tensors(),
                    IsqOrganization::MoeExpertsOnly => self
                        .residual_tensors_moe_experts_only()
                        .unwrap_or(self.residual_tensors()),
                };

                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
                let config_out = parent.join("config.json");
                let modules_out = parent.join("modules.json");
                let tokenizer_out = parent.join("tokenizer.json");
                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
                let chat_template_jinja_out = parent.join("chat_template.jinja");
                let gen_cfg_out = parent.join("generation_config.json");
                let processor_out = parent.join("processor_config.json");
                let preprocessor_out = parent.join("preprocessor_config.json");

                info!(
                    "Serializing {} residual tensors to `{}`.",
                    residual.len(),
                    residual_out.display()
                );

                safetensors::serialize_to_file(residual, None, &residual_out)?;

                let UqffFullSer {
                    tokenizer,
                    template_filename,
                    modules,
                    module_paths,
                    generation_config,
                    config,
                    processor_filename,
                    preprocessor_filename,
                } = full_ser;

                info!("Serializing configuration to `{}`.", config_out.display());

                std::fs::write(config_out, config)?;

                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());

                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
                    .map_err(candle_core::Error::msg)?;

                if let Some(template_filename) = template_filename {
                    let template =
                        std::fs::read(template_filename).map_err(candle_core::Error::msg)?;

                    if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
                        info!(
                            "Serializing chat template to `{}`.",
                            chat_template_jinja_out.display()
                        );
                        std::fs::write(&chat_template_jinja_out, template)
                            .map_err(candle_core::Error::msg)?;
                    } else {
                        info!(
                            "Serializing tokenizer config to `{}`.",
                            tokenizer_cfg_out.display()
                        );
                        std::fs::write(&tokenizer_cfg_out, template)
                            .map_err(candle_core::Error::msg)?;
                    }
                }

                if let Some(generation_config) = generation_config {
                    info!(
                        "Serializing generation config to `{}`.",
                        gen_cfg_out.display()
                    );

                    let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(processor_config) = processor_filename {
                    info!(
                        "Serializing processor config to `{}`.",
                        processor_out.display()
                    );

                    let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(preprocessor_config) = preprocessor_filename {
                    info!(
                        "Serializing preprocessor config to `{}`.",
                        preprocessor_out.display()
                    );

                    let cfg =
                        std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(modules) = modules {
                    info!(
                        "Serializing modules manifest to `{}`.",
                        modules_out.display()
                    );

                    std::fs::write(&modules_out, modules).map_err(candle_core::Error::msg)?;

                    if let Some(module_paths) = module_paths {
                        for module in module_paths {
                            match module {
                                EmbeddingModulePaths::Transformer { path }
                                | EmbeddingModulePaths::Pooling { path, .. }
                                | EmbeddingModulePaths::Dense { path, .. }
                                | EmbeddingModulePaths::Normalize { path } => {
                                    if path.is_empty() {
                                        continue;
                                    }
                                    let module_dir = parent.join(path.as_str());
                                    std::fs::create_dir_all(&module_dir)
                                        .map_err(candle_core::Error::msg)?;

                                    match module {
                                        EmbeddingModulePaths::Pooling { config, .. } => {
                                            let dest = module_dir.join("config.json");
                                            if config != &dest {
                                                std::fs::copy(config, &dest)
                                                    .map_err(candle_core::Error::msg)?;
                                            }
                                        }
                                        EmbeddingModulePaths::Dense { config, model, .. } => {
                                            let dest_cfg = module_dir.join("config.json");
                                            if config != &dest_cfg {
                                                std::fs::copy(config, &dest_cfg)
                                                    .map_err(candle_core::Error::msg)?;
                                            }
                                            let dest_model = module_dir.join("model.safetensors");
                                            if model != &dest_model {
                                                std::fs::copy(model, &dest_model)
                                                    .map_err(candle_core::Error::msg)?;
                                            }
                                        }
                                        EmbeddingModulePaths::Transformer { .. }
                                        | EmbeddingModulePaths::Normalize { .. } => {}
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        Ok(())
    }

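    /// Load ISQ artifacts (UQFF shards) back into this model's layers, deserializing
    /// each tensor onto its mapped device.
    ///
    /// A minimal sketch (paths are illustrative):
    ///
    /// ```ignore
    /// model.load_from_artifacts(
    ///     Device::Cpu,
    ///     None,  // topology
    ///     false, // silent
    ///     &[PathBuf::from("model-0.uqff"), PathBuf::from("model-1.uqff")],
    /// )?;
    /// ```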
    fn load_from_artifacts(
        &mut self,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        artifacts: &[PathBuf],
    ) -> candle_core::Result<()> {
        let (tensors, mapper) = self.get_layers();
        let total_tensors = tensors.len();

        let layers = topology.map(|x| {
            x.layers
                .iter()
                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                .collect::<Vec<_>>()
        });

        let mut devices = Vec::new();
        let mut comms = Vec::new();
        for (_, layer_num) in &tensors {
            let device = if let Some(ref layers) = layers {
                if let Some(layer) = layer_num {
                    layers
                        .get(*layer)
                        .as_ref()
                        .map(|x| x.1.clone())
                        .unwrap_or(Some(device.clone()))
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                }
            } else if let Some(layer_num) = layer_num {
                mapper
                    .device_for(*layer_num, false)
                    .cloned()
                    .unwrap_or(device.clone())
            } else {
                device.clone()
            };
            devices.push(device);
            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?);
        }

        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::multi(artifacts)? };

        let artifact_isqs = artifacts
            .tensors()
            .into_iter()
            .map(|(name, tensor)| {
                (
                    name.parse::<usize>()
                        .expect("Name should be parseable as usize"),
                    tensor,
                )
            })
            .collect::<HashMap<_, _>>();

        if artifact_isqs.len() != total_tensors {
            candle_core::bail!(
                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
                artifact_isqs.len(),
            );
        }

        let bar = ProgressBar::new(total_tensors as u64);
        configure_progress_bar(&bar);
        bar.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                .unwrap()
                .progress_chars("#>-"),
        );

        let t_start = Instant::now();

        let guard = QuantizeOntoGuard::new();

        if silent {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        } else {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .progress_with(bar)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        }

        let delta = Instant::now().duration_since(t_start).as_secs_f32();
        info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s");

        Ok(())
    }
}

/// Trait for loading models with ISQ.
pub(crate) trait IsqModelLoader {
    /// Regexes to match layers which will have standard *immediate* ISQ applied.
    ///
    /// Only called on non-adapter models!
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// Regexes to match layers which will have standard MoQE *immediate* ISQ applied.
    ///
    /// Only called on non-adapter models!
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }

    /// Regexes to match layers which will have standard ISQ applied.
    ///
    /// Only called on non-adapter models!
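    ///
    /// A sketch of a typical override (the regexes are illustrative):
    ///
    /// ```ignore
    /// fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
    ///     Ok(vec![
    ///         Regex::new(r"lm_head\.(weight|bias)$")?,
    ///         Regex::new(r"layers\.(\d+)\.mlp\.(up|gate|down)_proj\.(weight|bias)$")?,
    ///     ])
    /// }
    /// ```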
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

    /// Regexes to match layers which will have standard MoQE ISQ applied.
    ///
    /// Only called on non-adapter models!
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}