mistralrs_quant/unquantized/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Batch matrix multiplication
        maybe_init_cublas_lt_wrapper(a.device().clone());

        // Try custom GEMV for single-token decode (batch_size=1)
        #[cfg(feature = "cuda")]
        if crate::gemv::should_use_gemv(a, &self.w) {
            return crate::gemv::gemv(a, &self.w, self.b.as_ref());
        }

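        // Broadcast the 2D weight across any leading batch dims so the batched matmul
        // below sees matching shapes, e.g. a: (bsize, seq, in) with w: (out, in)
        // becomes w: (bsize, out, in).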
        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    // Try to use cublaslt, otherwise fallback to gemm
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) =
            (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
        {
            // cuBLAS batch_matmul requires 3D tensors, fall back to regular matmul for 2D
            if a.rank() >= 3 && w.rank() >= 3 {
                cublaslt
                    .batch_matmul(a, &w, None, None, None, None, None)?
                    .t()
            } else {
                MatMul.matmul(a, &w.t()?)
            }
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Weights are [num_experts, out_features, in_features]
        // For Metal path:
        //   - a: (b_size, seq_len, 1, 1, hidden_dim) - 5D
        //   - indices: (b_size, seq_len, num_experts_per_tok) - 3D
        // For CUDA path:
        //   - a: (num_tokens, 1, hidden_dim) - 3D
        //   - indices: (num_tokens, num_experts_per_tok) - 2D
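        //
        // Example (Metal path): with b_size=1, seq_len=4, num_experts_per_tok=2,
        // hidden_dim=8 and out_features=16, the output below has shape (1, 4, 2, 16).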

        let w = &self.w;
        let (_num_experts, out_features, _in_features) = w.dims3()?;

        match a.dims() {
            // Metal path: 5D input (b_size, seq_len, 1, 1, hidden_dim)
            &[b_size, seq_len, 1, 1, hidden_dim] => {
                let (_b, _s, num_experts_per_tok) = indices.dims3()?;
                // Flatten indices to select experts
                let flat_indices = indices.reshape((b_size * seq_len * num_experts_per_tok,))?;

                // Select expert weights: [b*s*k, out_features, in_features]
                let selected_w = w.index_select(&flat_indices, 0)?;

                // Reshape input: [b*s, hidden_dim]
                let a_flat = a.reshape((b_size * seq_len, hidden_dim))?;

                // For each token, we need to compute with each selected expert
                // Broadcast a to match: [b*s, 1, hidden_dim] -> [b*s, k, hidden_dim]
                let a_expanded = a_flat
                    .unsqueeze(1)?
                    .broadcast_as((b_size * seq_len, num_experts_per_tok, hidden_dim))?
                    .reshape((b_size * seq_len * num_experts_per_tok, hidden_dim))?;

                // Matmul: [b*s*k, hidden_dim] @ [b*s*k, hidden_dim, out_features] -> [b*s*k, out_features]
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                // Reshape back to [b, s, k, out_features]
                result.reshape((b_size, seq_len, num_experts_per_tok, out_features))
            }
            // CUDA path: 3D input (num_tokens, 1, hidden_dim)
            &[num_tokens, 1, hidden_dim] => {
                let (_, num_experts_per_tok) = indices.dims2()?;

                // Flatten indices
                let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;

                // Select expert weights: [n*k, out_features, in_features]
                let selected_w = w.index_select(&flat_indices, 0)?;

                // Broadcast input: [n, 1, hidden] -> [n, k, hidden] -> [n*k, hidden]
                let a_expanded = a
                    .broadcast_as((num_tokens, num_experts_per_tok, hidden_dim))?
                    .reshape((num_tokens * num_experts_per_tok, hidden_dim))?;

                // Matmul: [n*k, hidden] @ [n*k, hidden, out] -> [n*k, out]
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                // Reshape to [n, k, out]
                result.reshape((num_tokens, num_experts_per_tok, out_features))
            }
            dims => {
                candle_core::bail!(
                    "UnquantLinear::gather_forward: unsupported input shape {:?}",
                    dims
                );
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

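    // In-situ quantization: re-quantizes this unquantized layer on `device` into the
    // requested ISQ format (HQQ, AFQ, GGUF k-quants, or FP8), or simply moves the weights
    // to `device` when `dtype` is `None`. An optional imatrix is only used by the
    // GGUF-style quants; the other formats reject it.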
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // Ignore imatrix altogether

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (1 for unquantized), u8, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
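//
// For example, a serialized layer with a bias begins with the four little-endian bytes of
// `UQFF_VERSION`, then the byte 0x01 (`QuantizedSerdeType::Unquant`), then the byte 0x01
// (bias present), followed by the weight tensor data and the bias tensor data.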

impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for unquant is 1
        buffer.push(QuantizedSerdeType::Unquant as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}
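
// A minimal sanity-check sketch, shown for illustration only: it assumes a CPU device,
// F32 tensors, and the `QuantMethod`/`QuantizedSerde` APIs exactly as used above. The
// test name and shapes are arbitrary; it only verifies the forward output shape and the
// serialized header bytes described in the comment block above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unquant_forward_and_serialize_shapes() -> Result<()> {
        let dev = Device::Cpu;
        // w: (out_features, in_features), b: (out_features,)
        let w = Tensor::randn(0f32, 1f32, (16, 8), &dev)?;
        let b = Tensor::randn(0f32, 1f32, (16,), &dev)?;
        let layer = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(w, Some(b))))?;

        // Forward on a 2D activation: (batch, in_features) -> (batch, out_features).
        let a = Tensor::randn(0f32, 1f32, (4, 8), &dev)?;
        let out = layer.forward(&a)?;
        assert_eq!(out.dims(), &[4, 16]);

        // Serialized header: 4 bytes of UQFF_VERSION (LE), ISQ type byte, has-bias byte.
        let bytes = layer.serialize()?;
        assert_eq!(bytes[4], QuantizedSerdeType::Unquant as u8);
        assert_eq!(bytes[5], 1); // bias present
        Ok(())
    }
}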