mistralrs_quant/unquantized/mod.rs
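//! Unquantized (full-precision) linear layer, with in-place ISQ requantization
//! and UQFF (de)serialization support.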

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_HANDLE},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

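/// A linear layer whose weight (and optional bias) stays in full precision.
///
/// This is the non-quantized [`QuantMethod`] implementation. It can optionally
/// track imatrix statistics for calibration and be re-quantized later via
/// [`QuantMethod::apply_isq`].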
#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

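    /// Computes `a @ w^T (+ b)`. On CUDA this tries to route through cuBLASLt,
    /// passing the (transposed) bias to the batched matmul; on Metal/CPU, and as
    /// the CUDA fallback, the broadcast bias seeds the output of
    /// `matmul_with_alpha_beta`.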
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Batch matrix multiplication
        maybe_init_cublas_lt_wrapper(a.device().clone());

        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    // Try to use cublaslt, otherwise fallback to gemm
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), *CUBLASLT_HANDLE.lock().unwrap())
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let mut out = b.contiguous()?;
                        a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                        Ok(out)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let mut out = b.contiguous()?;
                    a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                    Ok(out)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let mut out = b.contiguous()?.to_dtype(DType::F32)?;
                        a.to_dtype(DType::F32)?.matmul_with_alpha_beta(
                            &w.t()?.to_dtype(DType::F32)?,
                            &mut out,
                            None,
                        )?;
                        out.to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let mut out = b.contiguous()?;
                        a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                        Ok(out)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) =
            (a.device(), *CUBLASLT_HANDLE.lock().unwrap())
        {
            cublaslt
                .batch_matmul(a, &w, None, None, None, None, None)?
                .t()
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Assume only one expert used.
        let w = self.w.index_select(indices, 0)?;

        a.broadcast_matmul(&w.t()?)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

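    /// Quantizes this layer onto `device` according to the requested ISQ type:
    /// HQQ (4/8 bit), AFQ, the GGUF k-quant/legacy formats (optionally
    /// imatrix-weighted), or FP8 E4M3. For `None`, the full-precision weights are
    /// simply moved to `device`.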
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire();
                // Ignore imatrix altogether

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (1 for unquantized), u8
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
//
// A sketch that checks this header layout is in the test module at the end of this file.

impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for unquant is 1
        buffer.push(QuantizedSerdeType::Unquant as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
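    /// Identical layout to [`Self::deserialize`], but any bias read from the buffer
    /// is returned separately rather than being attached to the layer.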
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}
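
// A minimal sketch of the UQFF header layout documented above, assuming a tiny
// CPU-only weight; the test name and shape are illustrative choices, and only
// items already imported in this module are used.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serialized_header_matches_documented_layout() -> Result<()> {
        // Build a small bias-free unquantized linear layer on the CPU.
        let w = Tensor::zeros((4, 8), DType::F32, &Device::Cpu)?;
        let layer = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(w, None)))?;

        // Serialize, then re-read the fixed-size header fields in order.
        let data = layer.serialize()?;
        let mut cursor = Cursor::new(data.as_ref());

        // UQFF version: u32, little endian.
        assert_eq!(cursor.read_u32::<LittleEndian>()?, UQFF_VERSION);
        // ISQ type tag: 1 for unquantized.
        assert_eq!(cursor.read_u8()? as usize, QuantizedSerdeType::Unquant as usize);
        // Bias flag: 0, since no bias was attached.
        assert_eq!(cursor.read_u8()?, 0);

        Ok(())
    }
}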