mistralrs_quant/gguf/mod.rs

#[cfg(not(feature = "cuda"))]
mod cpu;
#[cfg(feature = "cuda")]
mod cuda;
#[cfg(feature = "cuda")]
mod ffi;

use std::{
    borrow::Cow,
    io::{Cursor, Read},
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{
    quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};
use candle_nn::Module;

use crate::{
    generate_isq, generate_isq_imatrix,
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

#[derive(Debug)]
pub struct GgufMatMul {
    pub(crate) w: QMatMul,
    pub(crate) b: Option<Tensor>,
}

impl QuantMethod for GgufMatMul {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { q_weight, b } => Ok(Self {
                w: QMatMul::from_arc(q_weight)?,
                b,
            }),
            QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.w.dequantize_f16()?.to_dtype(DType::F32)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let x = self.w.forward(a)?;
        if let Some(ref b) = self.b {
            x.broadcast_add(b)
        } else {
            Ok(x)
        }
    }

    /// Compute the matmul of `self` and `x`. `self` should contain the weights.
    ///
    /// If `x` is (n_tokens, 1, cols) and the `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts_per_tok).
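    ///
    /// For example (illustrative shapes only): with weights for 8 experts of shape
    /// (8, 1024, 4096), `x` of shape (3, 1, 4096), and `indices` of shape (3, 2)
    /// selecting 2 experts per token, each token's hidden state is multiplied by its
    /// 2 selected expert matrices, giving an output of shape (3, 2, 1024) before the
    /// optional bias is broadcast-added.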
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Use indexed_moe_forward for efficient indexed matmul
        // Expected shapes:
        // - x: (n_tokens, 1, hidden_dim) or (n_tokens, n_experts_per_tok, hidden_dim)
        // - indices: (n_tokens, n_experts_per_tok)
        // - weights (self): (n_experts, out_features, in_features)
        #[cfg(feature = "cuda")]
        let res = cuda::qmatmul_indexed_moe_forward(&self.w, x, indices)?;

        // For CPU and Metal: use dequantize-then-matmul approach
        #[cfg(not(feature = "cuda"))]
        let res = cpu::cpu_indexed_moe_forward(&self.w, x, indices)?;

        if let Some(ref b) = self.b {
            res.broadcast_add(b)
        } else {
            Ok(res)
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        match self {
            Self {
                w: QMatMul::Tensor(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::Tensor((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::TensorF16(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::TensorF16((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::QTensor(w),
                b,
            } => {
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                let w = QMatMul::QTensor(std::sync::Arc::new(
                    candle_core::quantized::QTensor::quantize(&(w + delta)?, dtype)?,
                ));
                Ok(Arc::new(Self { w, b: b.clone() }))
            }
        }
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        match &self.w {
            QMatMul::QTensor(q) => (DType::F32, q.device()),
            QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()),
        }
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        if let Some(dtype) = dtype {
            let t = match &self.w {
                QMatMul::QTensor(q) => q.dequantize(&q.device())?,
                QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
            };
            let dtype = dtype.try_into()?;
            let res = if let Some(imatrix_weight) = imatrix_weight {
                generate_isq_imatrix!(t, imatrix_weight, device, dtype, n_quantized, guard)
            } else {
                generate_isq!(t, device, dtype, n_quantized, guard)
            };
            Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: res,
                b: self.b.clone(),
            })?))
        } else {
            let w = match &self.w {
                QMatMul::QTensor(q) => QMatMul::QTensor(Arc::new(QTensor::quantize(
                    &q.dequantize(&device)?,
                    q.dtype(),
                )?)),
                QMatMul::Tensor(t) => QMatMul::Tensor(t.to_device(&device)?),
                QMatMul::TensorF16(t) => QMatMul::TensorF16(t.to_device(&device)?),
            };
            let b = if let Some(b) = &self.b {
                Some(b.to_device(&device)?)
            } else {
                None
            };
            Ok(Arc::new(GgufMatMul { w, b }))
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (0 for GGUF), u8
// -----------------------
// Tensor data length in bytes, u32, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Quantized dtype, u32, little endian
// -----------------------
// Num shape dims, u32, little endian
// -----------------------
// ...
// Array (in original order): quantized weight shape dims, u32, little endian
// ...
// -----------------------
// ...
// Array: quantized weight data, u8s
// ...
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
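//
// For example (illustrative numbers, not taken from a real file): an 8x32 weight
// quantized to Q8_0 (34 bytes per 32-element block, so 8 blocks = 272 bytes of data)
// serializes to 4 (version) + 1 (ISQ type) + 4 (length) + 1 (bias flag) + 4 (dtype)
// + 4 (num dims) + 2*4 (dims) + 272 (data) = 298 bytes, followed by the bias tensor
// if one is present.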

impl QuantizedSerde for GgufMatMul {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "gguf"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = match &self.w {
            QMatMul::QTensor(qw) => {
                let w = qw.data()?.to_vec();
                let w_shape = qw.shape().dims();
                let dtype: u32 = match qw.dtype() {
                    GgmlDType::F32 => 0,
                    GgmlDType::F16 => 1,
                    GgmlDType::Q4_0 => 2,
                    GgmlDType::Q4_1 => 3,
                    GgmlDType::Q5_0 => 6,
                    GgmlDType::Q5_1 => 7,
                    GgmlDType::Q8_0 => 8,
                    GgmlDType::Q8_1 => 9,
                    GgmlDType::Q2K => 10,
                    GgmlDType::Q3K => 11,
                    GgmlDType::Q4K => 12,
                    GgmlDType::Q5K => 13,
                    GgmlDType::Q6K => 14,
                    GgmlDType::Q8K => 15,
                    // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
                    GgmlDType::BF16 => 30,
                };

                let mut buffer = Vec::new();

                // Version is always first!
                buffer.extend(&UQFF_VERSION.to_le_bytes());

                // ISQ type for GGUF is 0
                buffer.push(QuantizedSerdeType::Gguf as u8);

                // Length
                buffer.extend(&(w.len() as u32).to_le_bytes());

                // Has bias
                buffer.push(bias.is_some() as u8);

                // Dtype (u32)
                buffer.extend(&dtype.to_le_bytes());

                // Shape
                buffer.extend((w_shape.len() as u32).to_le_bytes());
                for dim in w_shape {
                    buffer.extend((*dim as u32).to_le_bytes());
                }

                // Quantized W Vec<u8> (just append it)
                buffer.extend(&w);

                buffer
            }
            QMatMul::TensorF16(_) | QMatMul::Tensor(_) => {
                candle_core::bail!("Cannot serialize non-quantized")
            }
        };

        if let Some(b) = bias.as_ref() {
            serialize_tensor(&mut buffer, b)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        // If we have bias
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok(Arc::new(Self {
            w: QMatMul::QTensor(w.into()),
            b,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        // If we have bias
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok((
            Arc::new(Self {
                w: QMatMul::QTensor(w.into()),
                b: None,
            }),
            b,
        ))
    }
}

impl GgufMatMul {
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        // Skip the tensor data length; only the dtype is needed here.
        let _ = buffer.read_u32::<LittleEndian>()? as usize;

        // Skip the bias flag.
        let _ = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        IsqType::try_from(dtype)
    }
}
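
// The test below is a minimal sketch added for illustration, not part of the original
// module: it round-trips the UQFF header documented above for a small Q8_0 weight on
// the CPU device. The test name and tensor shape are arbitrary illustrative choices;
// only APIs already used in this file (or standard candle/byteorder calls) are assumed.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn uqff_gguf_header_layout() -> Result<()> {
        let device = Device::Cpu;
        // 8x32 keeps the trailing dim a multiple of the Q8_0 block size (32).
        let w = Tensor::zeros((8, 32), DType::F32, &device)?;
        let q = QTensor::quantize(&w, GgmlDType::Q8_0)?;
        let layer = GgufMatMul::new(QuantMethodConfig::Gguf {
            q_weight: Arc::new(q),
            b: None,
        })?;

        let bytes = layer.serialize()?;
        let mut cursor = Cursor::new(bytes.as_ref());

        // Walk the header fields in the order documented above `impl QuantizedSerde`.
        assert_eq!(cursor.read_u32::<LittleEndian>()?, UQFF_VERSION);
        assert_eq!(cursor.read_u8()?, QuantizedSerdeType::Gguf as u8);
        let data_len = cursor.read_u32::<LittleEndian>()? as usize;
        assert_eq!(cursor.read_u8()?, 0); // no bias
        assert_eq!(cursor.read_u32::<LittleEndian>()?, 8); // Q8_0
        assert_eq!(cursor.read_u32::<LittleEndian>()?, 2); // two shape dims
        assert_eq!(cursor.read_u32::<LittleEndian>()?, 8);
        assert_eq!(cursor.read_u32::<LittleEndian>()?, 32);

        // Everything that remains is the raw quantized weight data
        // (for Q8_0 this is 34 bytes per 32-element block).
        let mut data = vec![0u8; data_len];
        cursor.read_exact(&mut data)?;
        assert_eq!(cursor.position() as usize, bytes.len());
        Ok(())
    }
}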