mistralrs_quant/gguf/mod.rs

use std::{
    borrow::Cow,
    io::{Cursor, Read},
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{
    quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};
use candle_nn::{Linear, Module};

use crate::{
    generate_isq, generate_isq_imatrix,
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

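/// A GGUF-quantized matmul layer: a quantized weight matrix (`QMatMul`) plus an optional
/// bias that is broadcast-added in `forward`.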
#[derive(Debug)]
pub struct GgufMatMul {
    pub(crate) w: QMatMul,
    pub(crate) b: Option<Tensor>,
}

impl QuantMethod for GgufMatMul {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { q_weight, b } => Ok(Self {
                w: QMatMul::from_arc(q_weight)?,
                b,
            }),
            QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.w.dequantize_f16()?.to_dtype(DType::F32)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let x = self.w.forward(a)?;
        if let Some(ref b) = self.b {
            x.broadcast_add(b)
        } else {
            Ok(x)
        }
    }

    /// Compute the matmul of `self` and `x`. `self` should contain the weights.
    ///
    /// If `x` is (n_tokens, n_experts, cols) and `self`'s weights are (n_experts, rows, cols),
    /// then `indices` is (n_tokens, n_experts).
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Always dequantize before the matmul.
        // TODO: add a specific kernel?
        let weight = self.dequantize_w()?;
        // Dispatch to the unquantized path; on CUDA it can use cuBLASLt (with fused bias),
        // so it is faster here.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.b.clone(),
        )))?;
        unquant.gather_forward(x, indices)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }

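    // Add a dense delta to the weight. For the `QTensor` variant this dequantizes, adds the
    // delta, and re-quantizes at the original GGML dtype.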
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        match self {
            Self {
                w: QMatMul::Tensor(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::Tensor((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::TensorF16(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::TensorF16((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::QTensor(w),
                b,
            } => {
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                let w = QMatMul::QTensor(std::sync::Arc::new(
                    candle_core::quantized::QTensor::quantize(&(w + delta)?, dtype)?,
                ));
                Ok(Arc::new(Self { w, b: b.clone() }))
            }
        }
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        match &self.w {
            QMatMul::QTensor(q) => (DType::F32, q.device()),
            QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()),
        }
    }

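    // In-situ quantization (ISQ): dequantize the current weight and re-quantize it to the
    // requested type, optionally guided by an importance matrix. With no target dtype, the
    // weight and bias are only moved to the target device.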
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        if let Some(dtype) = dtype {
            let t = match &self.w {
                QMatMul::QTensor(q) => q.dequantize(&q.device())?,
                QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
            };
            let dtype = dtype.try_into()?;
            let res = if let Some(imatrix_weight) = imatrix_weight {
                generate_isq_imatrix!(t, imatrix_weight, device, dtype, n_quantized, guard)
            } else {
                generate_isq!(t, device, dtype, n_quantized, guard)
            };
            Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: res,
                b: self.b.clone(),
            })?))
        } else {
            let w = match &self.w {
                QMatMul::QTensor(q) => QMatMul::QTensor(Arc::new(QTensor::quantize(
                    &q.dequantize(&device)?,
                    q.dtype(),
                )?)),
                QMatMul::Tensor(t) => QMatMul::Tensor(t.to_device(&device)?),
                QMatMul::TensorF16(t) => QMatMul::TensorF16(t.to_device(&device)?),
            };
            let b = if let Some(b) = &self.b {
                Some(b.to_device(&device)?)
            } else {
                None
            };
            Ok(Arc::new(GgufMatMul { w, b }))
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (0 for GGUF), u8
// -----------------------
// Tensor data length in bytes, u32, little endian
// -----------------------
// Whether bias data is included, u8 (boolean)
// -----------------------
// Quantized dtype, u32, little endian
// -----------------------
// Num shape dims, u32, little endian
// -----------------------
// ...
// Array (in original order): quantized weight shape dims, u32, little endian
// ...
// -----------------------
// ...
// Array: quantized weight data, u8s
// ...
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------

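// For reference, the fixed-size header above occupies byte offsets: version at 0..4, ISQ type
// at 4, data length at 5..9, has-bias flag at 9, dtype at 10..14, num dims at 14..18, followed
// by the dims array, the weight data, and the optional bias. This layout is written by
// `serialize_with_bias` and parsed by `deserialize`, `deserialize_ext_bias`, and
// `get_isq_type_from_uqff` below; keep all of them in sync.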
impl QuantizedSerde for GgufMatMul {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "gguf"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = match &self.w {
            QMatMul::QTensor(qw) => {
                let w = qw.data()?.to_vec();
                let w_shape = qw.shape().dims();
                let dtype: u32 = match qw.dtype() {
                    GgmlDType::F32 => 0,
                    GgmlDType::F16 => 1,
                    GgmlDType::Q4_0 => 2,
                    GgmlDType::Q4_1 => 3,
                    GgmlDType::Q5_0 => 6,
                    GgmlDType::Q5_1 => 7,
                    GgmlDType::Q8_0 => 8,
                    GgmlDType::Q8_1 => 9,
                    GgmlDType::Q2K => 10,
                    GgmlDType::Q3K => 11,
                    GgmlDType::Q4K => 12,
                    GgmlDType::Q5K => 13,
                    GgmlDType::Q6K => 14,
                    GgmlDType::Q8K => 15,
                    // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
                    GgmlDType::BF16 => 30,
                };

                let mut buffer = Vec::new();

                // Version is always first!
                buffer.extend(&UQFF_VERSION.to_le_bytes());

                // ISQ type for GGUF is 0
                buffer.push(QuantizedSerdeType::Gguf as u8);

                // Length
                buffer.extend(&(w.len() as u32).to_le_bytes());

                // Has bias
                buffer.push(bias.is_some() as u8);

                // Dtype (u32)
                buffer.extend(&dtype.to_le_bytes());

                // Shape
                buffer.extend((w_shape.len() as u32).to_le_bytes());
                for dim in w_shape {
                    buffer.extend((*dim as u32).to_le_bytes());
                }

                // Quantized W Vec<u8> (just append it)
                buffer.extend(&w);

                buffer
            }
            QMatMul::TensorF16(_) | QMatMul::Tensor(_) => {
                candle_core::bail!("Cannot serialize non-quantized")
            }
        };

        if let Some(b) = bias.as_ref() {
            serialize_tensor(&mut buffer, b)?;
        }

        Ok(Cow::from(buffer))
    }

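    // Parse the layout documented above and rebuild the quantized weight via `qtensor_from_ggml`.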
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        // If we have bias
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok(Arc::new(Self {
            w: QMatMul::QTensor(w.into()),
            b,
        }))
    }
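    // Same parsing as `deserialize`, but the bias (if present) is returned separately rather
    // than being stored on the returned layer.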
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        // If we have bias
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok((
            Arc::new(Self {
                w: QMatMul::QTensor(w.into()),
                b: None,
            }),
            b,
        ))
    }
}

impl GgufMatMul {
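    /// Read only the UQFF header of a serialized `GgufMatMul` to recover the stored ISQ type,
    /// without deserializing the tensor data.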
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let _ = buffer.read_u32::<LittleEndian>()? as usize;

        let _ = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        IsqType::try_from(dtype)
    }
}