mistralrs_quant/unquantized/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

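/// Unquantized linear layer: a plain weight tensor `w` (`[out_features, in_features]`,
/// or `[num_experts, out_features, in_features]` when used via `gather_forward`),
/// an optional bias `b`, and optional imatrix statistics collected for ISQ calibration.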
#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

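    /// Forward pass: computes `a @ w^T (+ b)`. On CUDA, a fused cuBLASLt batch
    /// matmul with bias is used when the controller is available; otherwise
    /// (and on Metal/CPU) this falls back to a plain matmul followed by a
    /// broadcast bias add (in f32 on CPU when the `accelerate` feature is enabled).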
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Batch matrix multiplication; lazily initialize the cuBLASLt wrapper
        // so the fused path below can use it when available.
        maybe_init_cublas_lt_wrapper(a.device().clone());

        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    // Try to use cublaslt, otherwise fallback to gemm
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get())
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) = (a.device(), CUBLASLT_CONTROLLER.get()) {
            // cuBLAS batch_matmul requires 3D tensors, fall back to regular matmul for 2D
            if a.rank() >= 3 && w.rank() >= 3 {
                cublaslt
                    .batch_matmul(a, &w, None, None, None, None, None)?
                    .t()
            } else {
                MatMul.matmul(a, &w.t()?)
            }
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

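    /// Expert-gathered forward pass for MoE layers: `indices` selects, per token,
    /// which expert weight matrices from `w` (shaped `[num_experts, out_features,
    /// in_features]`) to apply. The two accepted input layouts (Metal 5D and CUDA 3D)
    /// are documented below.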
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Weights are [num_experts, out_features, in_features]
        // For Metal path:
        //   - a: (b_size, seq_len, 1, 1, hidden_dim) - 5D
        //   - indices: (b_size, seq_len, num_experts_per_tok) - 3D
        // For CUDA path:
        //   - a: (num_tokens, 1, hidden_dim) - 3D
        //   - indices: (num_tokens, num_experts_per_tok) - 2D

        let w = &self.w;
        let (_num_experts, out_features, _in_features) = w.dims3()?;

        match a.dims() {
            // Metal path: 5D input (b_size, seq_len, 1, 1, hidden_dim)
            &[b_size, seq_len, 1, 1, hidden_dim] => {
                let (_b, _s, num_experts_per_tok) = indices.dims3()?;
                // Flatten indices to select experts
                let flat_indices = indices.reshape((b_size * seq_len * num_experts_per_tok,))?;

                // Select expert weights: [b*s*k, out_features, in_features]
                let selected_w = w.index_select(&flat_indices, 0)?;

                // Reshape input: [b*s, hidden_dim]
                let a_flat = a.reshape((b_size * seq_len, hidden_dim))?;

                // For each token, we need to compute with each selected expert
                // Broadcast a to match: [b*s, 1, hidden_dim] -> [b*s, k, hidden_dim]
                let a_expanded = a_flat
                    .unsqueeze(1)?
                    .broadcast_as((b_size * seq_len, num_experts_per_tok, hidden_dim))?
                    .reshape((b_size * seq_len * num_experts_per_tok, hidden_dim))?;

                // Matmul: [b*s*k, hidden_dim] @ [b*s*k, hidden_dim, out_features] -> [b*s*k, out_features]
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                // Reshape back to [b, s, k, out_features]
                result.reshape((b_size, seq_len, num_experts_per_tok, out_features))
            }
            // CUDA path: 3D input (num_tokens, 1, hidden_dim)
            &[num_tokens, 1, hidden_dim] => {
                let (_, num_experts_per_tok) = indices.dims2()?;

                // Flatten indices
                let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;

                // Select expert weights: [n*k, out_features, in_features]
                let selected_w = w.index_select(&flat_indices, 0)?;

                // Broadcast input: [n, 1, hidden] -> [n, k, hidden] -> [n*k, hidden]
                let a_expanded = a
                    .broadcast_as((num_tokens, num_experts_per_tok, hidden_dim))?
                    .reshape((num_tokens * num_experts_per_tok, hidden_dim))?;

                // Matmul: [n*k, hidden] @ [n*k, hidden, out] -> [n*k, out]
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                // Reshape to [n, k, out]
                result.reshape((num_tokens, num_experts_per_tok, out_features))
            }
            dims => {
                candle_core::bail!(
                    "UnquantLinear::gather_forward: unsupported input shape {:?}",
                    dims
                );
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

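    /// In-situ quantization: converts this unquantized layer into a quantized
    /// `QuantMethod` on `device` according to `dtype` (HQQ, AFQ, GGUF k-quants,
    /// or FP8). An optional imatrix is honored only by the GGUF path; `None`
    /// simply moves the unquantized weights to `device`.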
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // Ignore imatrix altogether

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (1 for unquantized), u8, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
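//
// A minimal round-trip sketch (illustrative only; `device`, `comm`, and `guard`
// are assumed to be in scope):
//
//     let bytes = layer.serialize()?;
//     let restored = UnquantLinear::deserialize(bytes, &device, &comm, guard)?;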

impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for unquant is 1
        buffer.push(QuantizedSerdeType::Unquant as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
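    /// Like `deserialize`, but returns any bias found in the buffer separately
    /// instead of storing it on the returned layer.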
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}