mistralrs_quant/unquantized/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

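/// Unquantized linear layer: a plain weight tensor `w`, an optional bias `b`,
/// and optional imatrix statistics collected for calibration (`stats`).
///
/// A minimal usage sketch (imports, device setup, and error handling elided):
///
/// ```ignore
/// let lin = UnquantLinear::new(QuantMethodConfig::Unquantized(linear))?;
/// let y = lin.forward(&x)?; // y = x · wᵀ (+ b)
/// ```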
#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            // Every other config is constructed by its own `QuantMethod` impl,
            // so reaching one of these variants here is a bug.
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

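    // Computes `a · wᵀ (+ b)`, preferring the fused cuBLASLt path on CUDA and
    // falling back to a plain matmul plus broadcast add elsewhere.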
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Batch matrix multiplication; make sure the cuBLASLt wrapper is
        // initialized for this device before taking the fused path below.
        maybe_init_cublas_lt_wrapper(a.device().clone());

        // Broadcast the weight over any leading batch dims of the input.
        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            // Broadcast the bias up to the output shape `[..., out_features]`.
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    // Try to use cublaslt, otherwise fallback to gemm
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get())
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    // With the `accelerate` feature, compute in f32 and cast
                    // back to the original dtype afterwards.
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) = (a.device(), CUBLASLT_CONTROLLER.get()) {
            cublaslt
                .batch_matmul(a, &w, None, None, None, None, None)?
                .t()
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

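    // MoE-style path: gather the expert weights selected by `indices` along
    // dim 0, then apply them to the input.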
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Assume only one expert used.
        let w = self.w.index_select(indices, 0)?;

        a.broadcast_matmul(&w.t()?)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

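    // Re-quantizes this unquantized layer into the requested in-situ
    // quantization (ISQ) format. A usage sketch with placeholder names
    // (`layer`, `n`, `device`, and `guard` come from the caller; error
    // handling elided):
    //
    //     let q = layer.apply_isq(Some(IsqType::Q4K), device, &n, None, guard)?;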
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // Ignore imatrix altogether

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

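    // Imatrix calibration: `begin_track_stats` installs a collector over the
    // layer inputs; `end_track_stats` computes the importance matrix and
    // resets the collector.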
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (1 for unquantized), u8
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
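//
// For example, a serialized layer that has a bias begins with the bytes
// [UQFF_VERSION as u32, LE] [0x01 = unquantized] [0x01 = has bias], followed by
// the weight tensor and then the bias tensor, both in `serialize_tensor` layout.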

impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for unquant is 1
        buffer.push(QuantizedSerdeType::Unquant as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        // Unlike `deserialize`, the bias is not attached to the layer; it is
        // returned separately for the caller to place.
        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}