mistralrs_quant/gguf/mod.rs

use std::{
    borrow::Cow,
    io::{Cursor, Read},
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{
    quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};
use candle_nn::{Linear, Module};

use crate::{
    generate_isq, generate_isq_imatrix,
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

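/// A quantized matmul layer backed by a GGUF/GGML `QTensor` weight, with an optional bias.
///
/// A minimal usage sketch (illustrative only; assumes `q_weight` is an `Arc<QTensor>`
/// already loaded from a GGUF file and `x` is an activation tensor):
///
/// ```ignore
/// let layer = GgufMatMul::new(QuantMethodConfig::Gguf { q_weight, b: None })?;
/// let y = layer.forward(&x)?;
/// ```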
#[derive(Debug)]
pub struct GgufMatMul {
    pub(crate) w: QMatMul,
    pub(crate) b: Option<Tensor>,
}

impl QuantMethod for GgufMatMul {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { q_weight, b } => Ok(Self {
                w: QMatMul::from_arc(q_weight)?,
                b,
            }),
            QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.w.dequantize_f16()?.to_dtype(DType::F32)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let x = self.w.forward(a)?;
        if let Some(ref b) = self.b {
            x.broadcast_add(b)
        } else {
            Ok(x)
        }
    }

    /// Compute the matmul of `self` and `x`, where `self` holds the weights.
    ///
    /// If `x` is (n_tokens, n_experts, cols) and the weights in `self` are
    /// (n_experts, rows, cols), then `indices` is (n_tokens, n_experts).
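    ///
    /// Illustrative shapes (assumed, not normative): with `x` of shape (4, 2, 512),
    /// weights of shape (8, 1024, 512), and `indices` of shape (4, 2) selecting among
    /// the 8 experts, the output has shape (4, 2, 1024).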
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Always dequantize for the matmul.
        // TODO: add a specific kernel?
        let weight = self.dequantize_w()?;
        // Dispatch to the unquantized path, which always uses cuBLASLt for the bias add on
        // CUDA, so it is the better choice here.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.b.clone(),
        )))?;
        unquant.gather_forward(x, indices)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }

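    /// Return a new layer with `delta` added to the weight. For a quantized weight this
    /// dequantizes, adds the delta, and requantizes to the original GGML dtype.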
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        match self {
            Self {
                w: QMatMul::Tensor(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::Tensor((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::TensorF16(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::TensorF16((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::QTensor(w),
                b,
            } => {
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                let w = QMatMul::QTensor(std::sync::Arc::new(
                    candle_core::quantized::QTensor::quantize(&(w + delta)?, dtype)?,
                ));
                Ok(Arc::new(Self { w, b: b.clone() }))
            }
        }
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        match &self.w {
            QMatMul::QTensor(q) => (DType::F32, q.device()),
            QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()),
        }
    }

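    /// Requantize this layer in-situ (ISQ) to `dtype` on `device`, using the importance
    /// matrix when one is provided. If no target dtype is given, the existing weights are
    /// simply moved (any `QTensor` is requantized to its current dtype on the target device).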
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        if let Some(dtype) = dtype {
            let t = match &self.w {
                QMatMul::QTensor(q) => q.dequantize(&q.device())?,
                QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
            };
            let dtype = dtype.try_into()?;
            let res = if let Some(imatrix_weight) = imatrix_weight {
                generate_isq_imatrix!(t, imatrix_weight, device, dtype, n_quantized, guard)
            } else {
                generate_isq!(t, device, dtype, n_quantized, guard)
            };
            Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: res,
                b: self.b.clone(),
            })?))
        } else {
            let w = match &self.w {
                QMatMul::QTensor(q) => QMatMul::QTensor(Arc::new(QTensor::quantize(
                    &q.dequantize(&device)?,
                    q.dtype(),
                )?)),
                QMatMul::Tensor(t) => QMatMul::Tensor(t.to_device(&device)?),
                QMatMul::TensorF16(t) => QMatMul::TensorF16(t.to_device(&device)?),
            };
            let b = if let Some(b) = &self.b {
                Some(b.to_device(&device)?)
            } else {
                None
            };
            Ok(Arc::new(GgufMatMul { w, b }))
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (0 for GGUF), u8
// -----------------------
// Tensor data length in bytes, u32, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Quantized dtype, u32, little endian
// -----------------------
// Num shape dims, u32, little endian
// -----------------------
// ...
// Array (in original order): quantized weight shape dims, u32, little endian
// ...
// -----------------------
// ...
// Array: quantized weight data, u8s
// ...
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
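//
// Illustrative sketch (not an authoritative byte dump): for a Q4K weight of shape
// [4096, 512] with a bias, the buffer would begin
//   [UQFF_VERSION: u32 LE] [0u8] [data_len: u32 LE] [1u8] [12u32 LE] [2u32 LE] [4096u32 LE] [512u32 LE]
// followed by `data_len` bytes of quantized weight data and then the bias as written by
// `serialize_tensor`.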
impl QuantizedSerde for GgufMatMul {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "gguf"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = match &self.w {
            QMatMul::QTensor(qw) => {
                let w = qw.data()?.to_vec();
                let w_shape = qw.shape().dims();
                let dtype: u32 = match qw.dtype() {
                    GgmlDType::F32 => 0,
                    GgmlDType::F16 => 1,
                    GgmlDType::Q4_0 => 2,
                    GgmlDType::Q4_1 => 3,
                    GgmlDType::Q5_0 => 6,
                    GgmlDType::Q5_1 => 7,
                    GgmlDType::Q8_0 => 8,
                    GgmlDType::Q8_1 => 9,
                    GgmlDType::Q2K => 10,
                    GgmlDType::Q3K => 11,
                    GgmlDType::Q4K => 12,
                    GgmlDType::Q5K => 13,
                    GgmlDType::Q6K => 14,
                    GgmlDType::Q8K => 15,
                    // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
                    GgmlDType::BF16 => 30,
                };

                let mut buffer = Vec::new();

                // Version is always first!
                buffer.extend(&UQFF_VERSION.to_le_bytes());

                // ISQ type for GGUF is 0
                buffer.push(QuantizedSerdeType::Gguf as u8);

                // Length
                buffer.extend(&(w.len() as u32).to_le_bytes());

                // Has bias
                buffer.push(bias.is_some() as u8);

                // Dtype (u32)
                buffer.extend(&dtype.to_le_bytes());

                // Shape
                buffer.extend((w_shape.len() as u32).to_le_bytes());
                for dim in w_shape {
                    buffer.extend((*dim as u32).to_le_bytes());
                }

                // Quantized W Vec<u8> (just append it)
                buffer.extend(&w);

                buffer
            }
            QMatMul::TensorF16(_) | QMatMul::Tensor(_) => {
                candle_core::bail!("Cannot serialize non-quantized")
            }
        };

        if let Some(b) = bias.as_ref() {
            serialize_tensor(&mut buffer, b)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire();
        // If we have bias
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok(Arc::new(Self {
            w: QMatMul::QTensor(w.into()),
            b,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire();
        // If we have bias
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok((
            Arc::new(Self {
                w: QMatMul::QTensor(w.into()),
                b: None,
            }),
            b,
        ))
    }
}

impl GgufMatMul {
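    /// Read just the UQFF header to recover the quantization (ISQ) type of a serialized
    /// `GgufMatMul`, without deserializing the tensor data itself.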
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        // Skip the tensor data length; it is not needed to determine the ISQ type.
        let _ = buffer.read_u32::<LittleEndian>()? as usize;

        // Skip the has-bias flag.
        let _ = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        IsqType::try_from(dtype)
    }
}