mistralrs_quant/gguf/mod.rs

use std::{
    borrow::Cow,
    io::{Cursor, Read},
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{
    quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};
use candle_nn::Module;

use crate::{
    generate_isq, generate_isq_imatrix,
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

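/// A GGUF-quantized linear layer: a weight wrapped in `QMatMul` plus an optional bias
/// that is broadcast-added after the matmul.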
#[derive(Debug)]
pub struct GgufMatMul {
    pub(crate) w: QMatMul,
    pub(crate) b: Option<Tensor>,
}

impl QuantMethod for GgufMatMul {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { q_weight, b } => Ok(Self {
                w: QMatMul::from_arc(q_weight)?,
                b,
            }),
            QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.w.dequantize_f16()?.to_dtype(DType::F32)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let x = self.w.forward(a)?;
        if let Some(ref b) = self.b {
            x.broadcast_add(b)
        } else {
            Ok(x)
        }
    }

    /// Compute matmul of `self` and `x`. `self` should contain the weights.
    ///
    /// If `x` is (n_tokens, 1, cols) and `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts_per_tok).
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Use QMatMul::indexed_moe_forward for efficient indexed matmul on CUDA.
        // Expected shapes:
        // - x: (n_tokens, 1, hidden_dim) or (n_tokens, n_experts_per_tok, hidden_dim)
        // - indices: (n_tokens, n_experts_per_tok)
        // - weights (self): (n_experts, out_features, in_features)
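        // For example, with 8 experts, top-2 routing, hidden_dim = 4096, and
        // out_features = 14336, that is x: (n_tokens, 1, 4096),
        // indices: (n_tokens, 2), and weights: (8, 14336, 4096).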
        let res = self.w.indexed_moe_forward(x, indices)?;
        if let Some(ref b) = self.b {
            res.broadcast_add(b)
        } else {
            Ok(res)
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }

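    /// Return a new layer whose weight is `w + delta`. A quantized weight is dequantized,
    /// offset by `delta`, and requantized with its original `GgmlDType`, so the update is lossy.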
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        match self {
            Self {
                w: QMatMul::Tensor(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::Tensor((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::TensorF16(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::TensorF16((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::QTensor(w),
                b,
            } => {
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                let w = QMatMul::QTensor(std::sync::Arc::new(
                    candle_core::quantized::QTensor::quantize(&(w + delta)?, dtype)?,
                ));
                Ok(Arc::new(Self { w, b: b.clone() }))
            }
        }
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        match &self.w {
            QMatMul::QTensor(q) => (DType::F32, q.device()),
            QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()),
        }
    }

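    /// Requantize this layer to the requested ISQ `dtype` on `device`, optionally guided by an
    /// imatrix. If `dtype` is `None`, the existing weight (and bias) are only moved to `device`,
    /// keeping their current quantization.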
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        if let Some(dtype) = dtype {
            let t = match &self.w {
                QMatMul::QTensor(q) => q.dequantize(&q.device())?,
                QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
            };
            let dtype = dtype.try_into()?;
            let res = if let Some(imatrix_weight) = imatrix_weight {
                generate_isq_imatrix!(t, imatrix_weight, device, dtype, n_quantized, guard)
            } else {
                generate_isq!(t, device, dtype, n_quantized, guard)
            };
            Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: res,
                b: self.b.clone(),
            })?))
        } else {
            let w = match &self.w {
                QMatMul::QTensor(q) => QMatMul::QTensor(Arc::new(QTensor::quantize(
                    &q.dequantize(&device)?,
                    q.dtype(),
                )?)),
                QMatMul::Tensor(t) => QMatMul::Tensor(t.to_device(&device)?),
                QMatMul::TensorF16(t) => QMatMul::TensorF16(t.to_device(&device)?),
            };
            let b = if let Some(b) = &self.b {
                Some(b.to_device(&device)?)
            } else {
                None
            };
            Ok(Arc::new(GgufMatMul { w, b }))
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (0 for GGUF), u8
// -----------------------
// Tensor data length in bytes, u32, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Quantized dtype, u32, little endian
// -----------------------
// Num shape dims, u32, little endian
// -----------------------
// ...
// Array (in original order): quantized weight shape dims, u32, little endian
// ...
// -----------------------
// ...
// Array: quantized weight data, u8s
// ...
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
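//
// Illustrative example: a Q6K weight of shape (4096, 11008) with a bias serializes as
// UQFF_VERSION (u32) | 0u8 | data_len (u32) | 1u8 | 14u32 | 2u32 | 4096u32 | 11008u32 |
// <data_len bytes of Q6K blocks> | <bias written by `serialize_tensor`>.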

impl QuantizedSerde for GgufMatMul {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "gguf"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = match &self.w {
            QMatMul::QTensor(qw) => {
                let w = qw.data()?.to_vec();
                let w_shape = qw.shape().dims();
                let dtype: u32 = match qw.dtype() {
                    GgmlDType::F32 => 0,
                    GgmlDType::F16 => 1,
                    GgmlDType::Q4_0 => 2,
                    GgmlDType::Q4_1 => 3,
                    GgmlDType::Q5_0 => 6,
                    GgmlDType::Q5_1 => 7,
                    GgmlDType::Q8_0 => 8,
                    GgmlDType::Q8_1 => 9,
                    GgmlDType::Q2K => 10,
                    GgmlDType::Q3K => 11,
                    GgmlDType::Q4K => 12,
                    GgmlDType::Q5K => 13,
                    GgmlDType::Q6K => 14,
                    GgmlDType::Q8K => 15,
                    // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
                    GgmlDType::BF16 => 30,
                };

                let mut buffer = Vec::new();

                // Version is always first!
                buffer.extend(&UQFF_VERSION.to_le_bytes());

                // ISQ type for GGUF is 0
                buffer.push(QuantizedSerdeType::Gguf as u8);

                // Length
                buffer.extend(&(w.len() as u32).to_le_bytes());

                // Has bias
                buffer.push(bias.is_some() as u8);

                // Dtype (u32)
                buffer.extend(&dtype.to_le_bytes());

                // Shape
                buffer.extend((w_shape.len() as u32).to_le_bytes());
                for dim in w_shape {
                    buffer.extend((*dim as u32).to_le_bytes());
                }

                // Quantized W Vec<u8> (just append it)
                buffer.extend(&w);

                buffer
            }
            QMatMul::TensorF16(_) | QMatMul::Tensor(_) => {
                candle_core::bail!("Cannot serialize non-quantized")
            }
        };

        if let Some(b) = bias.as_ref() {
            serialize_tensor(&mut buffer, b)?;
        }

        Ok(Cow::from(buffer))
    }

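    /// Deserialize a `GgufMatMul` from a UQFF buffer with the layout described above,
    /// materializing the quantized weight (and any bias) on `device`.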
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        // If we have bias
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok(Arc::new(Self {
            w: QMatMul::QTensor(w.into()),
            b,
        }))
    }
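    /// Like `deserialize`, but returns the bias separately instead of storing it on the
    /// returned layer.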
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        // If we have bias
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok((
            Arc::new(Self {
                w: QMatMul::QTensor(w.into()),
                b: None,
            }),
            b,
        ))
    }
}

impl GgufMatMul {
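    /// Parse only the UQFF header (version, serde type, data length, bias flag, dtype) to
    /// recover the `IsqType` without reading the tensor data.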
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        // Skip the tensor data length.
        let _ = buffer.read_u32::<LittleEndian>()? as usize;

        // Skip the has-bias flag.
        let _ = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        IsqType::try_from(dtype)
    }
}
451}