mistralrs_quant/unquantized/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};
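
/// Linear layer kept unquantized: the raw weight tensor, an optional bias, and
/// optional imatrix activation statistics used for ISQ calibration.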
#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }
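
    /// Computes `a @ w^T (+ b)`. On CUDA this tries the cuBLASLt batched matmul
    /// (with fused bias) and otherwise falls back to a plain matmul; activations
    /// are also fed to the imatrix statistics tracker when one is installed.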
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Batch matrix multiplication
        maybe_init_cublas_lt_wrapper(a.device().clone());

        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    // Try to use cublaslt, otherwise fallback to gemm
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get())
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) = (a.device(), CUBLASLT_CONTROLLER.get()) {
            cublaslt
                .batch_matmul(a, &w, None, None, None, None, None)?
                .t()
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }
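
    /// Forward pass that first gathers expert weights via `indices` along
    /// dimension 0, then performs a broadcast matmul.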
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Assume only one expert used.
        let w = self.w.index_select(indices, 0)?;

        a.broadcast_matmul(&w.t()?)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }
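
    /// Applies in-situ quantization (ISQ): converts this unquantized layer into
    /// the requested representation (HQQ, AFQ, GGUF k-quants, or FP8) on `device`,
    /// or simply moves the tensors to `device` when `dtype` is `None`.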
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // Ignore imatrix altogether

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }
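
    /// Installs an imatrix statistics tracker for this layer so that subsequent
    /// forward passes record activation statistics.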
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (1 for unquantized), u8, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
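//
// For example (a hypothetical sketch, not bytes taken from a real UQFF file),
// a serialized layer that includes a bias therefore begins:
//
//   [UQFF_VERSION: u32 LE] [1u8 = unquantized] [1u8 = has bias] [weight tensor bytes...] [bias tensor bytes...]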

impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for unquant is 1
        buffer.push(QuantizedSerdeType::Unquant as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
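    /// Like `deserialize`, but hands the deserialized bias back to the caller
    /// instead of storing it on the returned layer.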
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}
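
// A minimal usage sketch (not part of the upstream file): wrap a plain
// `candle_nn::Linear` in `UnquantLinear` and run a CPU forward pass. The shapes,
// zero-filled tensors, and module name are illustrative assumptions only.
#[cfg(test)]
mod unquant_usage_sketch {
    use super::*;

    #[test]
    fn forward_has_expected_shape() -> Result<()> {
        let dev = Device::Cpu;
        // Weight is (out_features, in_features) = (3, 4); input is (batch, in_features) = (2, 4).
        let w = Tensor::zeros((3, 4), DType::F32, &dev)?;
        let a = Tensor::zeros((2, 4), DType::F32, &dev)?;

        let layer = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(w, None)))?;
        let out = layer.forward(&a)?;

        // a @ w^T yields (batch, out_features).
        assert_eq!(out.dims(), &[2, 3]);
        Ok(())
    }
}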