mistralrs_quant/unquantized/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Batch matrix multiplication
        maybe_init_cublas_lt_wrapper(a.device().clone());

        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    // Try to use cublaslt, otherwise fall back to gemm
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get())
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let mut out = b.contiguous()?;
                        a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                        Ok(out)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let mut out = b.contiguous()?;
                    a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                    Ok(out)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let mut out = b.contiguous()?.to_dtype(DType::F32)?;
                        a.to_dtype(DType::F32)?.matmul_with_alpha_beta(
                            &w.t()?.to_dtype(DType::F32)?,
                            &mut out,
                            None,
                        )?;
                        out.to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let mut out = b.contiguous()?;
                        a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                        Ok(out)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) = (a.device(), CUBLASLT_CONTROLLER.get()) {
            cublaslt
                .batch_matmul(a, &w, None, None, None, None, None)?
                .t()
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Assume only one expert used.
        let w = self.w.index_select(indices, 0)?;

        a.broadcast_matmul(&w.t()?)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // Ignore imatrix altogether

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (1 for unquantized), u8
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------

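// Illustrative sketch of the layout documented above: read only the fixed-size
// header fields of a serialized unquantized UQFF buffer. The real loading path
// is `deserialize` below; this helper (the name `read_unquant_uqff_header` is
// made up for this sketch and is not part of the UQFF API) exists purely to
// make the byte layout concrete.
#[allow(dead_code)]
fn read_unquant_uqff_header(data: &[u8]) -> std::io::Result<(u32, u8, bool)> {
    let mut cursor = Cursor::new(data);
    let version = cursor.read_u32::<LittleEndian>()?; // UQFF version, u32, little endian
    let isq_type = cursor.read_u8()?; // 1 == `QuantizedSerdeType::Unquant`
    let has_bias = cursor.read_u8()? != 0; // u8 boolean
    // The weight tensor data (and, if `has_bias`, the bias tensor data) follow;
    // their layout is handled by `serialize_tensor`/`deserialize_tensor`.
    Ok((version, isq_type, has_bias))
}
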
impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for unquant is 1
        buffer.push(QuantizedSerdeType::Unquant as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}
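
// A minimal usage sketch (added for illustration; this test module is an
// assumption about how the layer is driven, not part of the original module):
// build an `UnquantLinear` from a plain candle `Linear` and run a forward pass
// on the CPU. The weight follows the (out_features, in_features) convention
// used by `forward` above, and the activation is (batch, seq_len, in_features).
#[cfg(test)]
mod unquant_linear_sketch_tests {
    use super::*;

    #[test]
    fn forward_output_shape() -> Result<()> {
        let dev = Device::Cpu;
        // Weight: (out = 8, in = 4); bias: (out,).
        let w = Tensor::randn(0f32, 1f32, (8, 4), &dev)?;
        let b = Tensor::zeros((8,), DType::F32, &dev)?;
        let layer =
            UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(w, Some(b))))?;

        // Activation: (batch = 2, seq_len = 3, in = 4).
        let x = Tensor::randn(0f32, 1f32, (2, 3, 4), &dev)?;
        let y = layer.forward(&x)?;
        assert_eq!(y.dims(), &[2, 3, 8]);
        Ok(())
    }
}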