mistralrs_quant/unquantized/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

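/// A full-precision (unquantized) linear layer: weight, optional bias, and
/// optional imatrix activation statistics gathered during calibration.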
#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

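    /// Computes `a @ w^T` (plus the bias when present). On CUDA this prefers a
    /// custom GEMV for single-token decode and cuBLASLt when available; on CPU
    /// with the `accelerate` feature the biased path runs in f32.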
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Batch matrix multiplication
        maybe_init_cublas_lt_wrapper(a.device().clone());

        // Try custom GEMV for single-token decode (batch_size=1)
        #[cfg(feature = "cuda")]
        if crate::gemv::should_use_gemv(a, &self.w) {
            return crate::gemv::gemv(a, &self.w, self.b.as_ref());
        }

        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    // Try to use cuBLASLt, otherwise fall back to gemm
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get())
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) = (a.device(), CUBLASLT_CONTROLLER.get()) {
            // cuBLAS batch_matmul requires 3D tensors, fall back to regular matmul for 2D
            if a.rank() >= 3 && w.rank() >= 3 {
                cublaslt
                    .batch_matmul(a, &w, None, None, None, None, None)?
                    .t()
            } else {
                MatMul.matmul(a, &w.t()?)
            }
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

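    /// MoE gather: selects one expert weight matrix per `(token, expert)` pair via
    /// `indices`, applies it to that token's hidden state, and returns one output
    /// per pair, shaped `[.., num_experts_per_tok, out_features]`.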
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Weights are [num_experts, out_features, in_features]
        // For Metal path:
        //   - a: (b_size, seq_len, 1, 1, hidden_dim) - 5D
        //   - indices: (b_size, seq_len, num_experts_per_tok) - 3D
        // For CUDA path:
        //   - a: (num_tokens, 1, hidden_dim) - 3D
        //   - indices: (num_tokens, num_experts_per_tok) - 2D

        let w = &self.w;
        let (_num_experts, out_features, _in_features) = w.dims3()?;

        match a.dims() {
            // Metal path: 5D input (b_size, seq_len, 1, 1, hidden_dim)
            &[b_size, seq_len, 1, 1, hidden_dim] => {
                let (_b, _s, num_experts_per_tok) = indices.dims3()?;
                // Flatten indices to select experts
                let flat_indices = indices.reshape((b_size * seq_len * num_experts_per_tok,))?;

                // Select expert weights: [b*s*k, out_features, in_features]
                let selected_w = w.index_select(&flat_indices, 0)?;

                // Reshape input: [b*s, hidden_dim]
                let a_flat = a.reshape((b_size * seq_len, hidden_dim))?;

                // For each token, we need to compute with each selected expert
                // Broadcast a to match: [b*s, 1, hidden_dim] -> [b*s, k, hidden_dim]
                let a_expanded = a_flat
                    .unsqueeze(1)?
                    .broadcast_as((b_size * seq_len, num_experts_per_tok, hidden_dim))?
                    .reshape((b_size * seq_len * num_experts_per_tok, hidden_dim))?;

                // Matmul: [b*s*k, hidden_dim] @ [b*s*k, hidden_dim, out_features] -> [b*s*k, out_features]
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                // Reshape back to [b, s, k, out_features]
                result.reshape((b_size, seq_len, num_experts_per_tok, out_features))
            }
            // CUDA path: 3D input (num_tokens, 1, hidden_dim)
            &[num_tokens, 1, hidden_dim] => {
                let (_, num_experts_per_tok) = indices.dims2()?;

                // Flatten indices
                let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;

                // Select expert weights: [n*k, out_features, in_features]
                let selected_w = w.index_select(&flat_indices, 0)?;

                // Broadcast input: [n, 1, hidden] -> [n, k, hidden] -> [n*k, hidden]
                let a_expanded = a
                    .broadcast_as((num_tokens, num_experts_per_tok, hidden_dim))?
                    .reshape((num_tokens * num_experts_per_tok, hidden_dim))?;

                // Matmul: [n*k, hidden] @ [n*k, hidden, out] -> [n*k, out]
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                // Reshape to [n, k, out]
                result.reshape((num_tokens, num_experts_per_tok, out_features))
            }
            dims => {
                candle_core::bail!(
                    "UnquantLinear::gather_forward: unsupported input shape {:?}",
                    dims
                );
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

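    /// In-situ quantization: converts this layer to the requested `IsqType`
    /// (HQQ, AFQ, GGUF k-quants, or FP8), optionally weighting GGUF quantization
    /// with an imatrix; `None` simply moves the unquantized layer to `device`.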
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // Ignore imatrix altogether

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

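    /// Starts recording activation statistics for this layer; `end_track_stats`
    /// later turns them into an importance matrix (imatrix) and clears the buffer.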
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (1 for unquantized), u8, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
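//
// As an illustration (not a dump of a real file): for a layer with a bias the
// stream is the 4-byte little-endian UQFF_VERSION, the byte 1
// (QuantizedSerdeType::Unquant), the byte 1 (bias present), then the weight
// tensor payload and finally the bias tensor payload, both written by
// `serialize_tensor`.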

impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for unquant is 1
        buffer.push(QuantizedSerdeType::Unquant as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}
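
// A minimal round-trip sketch, not part of the upstream test suite. It assumes
// `QuantizeOntoGuard::new()` exists as a no-argument constructor, which may not
// match the real crate; treat this as illustrative only.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unquant_uqff_roundtrip() -> Result<()> {
        let device = Device::Cpu;
        let w = Tensor::randn(0f32, 1f32, (4, 8), &device)?;
        let b = Tensor::zeros((4,), DType::F32, &device)?;
        let layer = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            w.clone(),
            Some(b),
        )))?;

        // Serialize, then read the layer back with the bias returned separately.
        let bytes = layer.serialize()?;
        let (restored, bias) =
            UnquantLinear::deserialize_ext_bias(bytes, &device, QuantizeOntoGuard::new())?;

        let (rw, _) = restored.unquant_weight_bias().expect("unquantized weights");
        assert_eq!(rw.dims(), w.dims());
        assert!(bias.is_some());
        Ok(())
    }
}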