mistralrs_quant/pertensor_fp8/
mod.rs

1use std::{
2    borrow::Cow,
3    sync::{atomic::AtomicUsize, Arc},
4};
5
6use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
7use candle_nn::Linear;
8
9mod ops;
10
11use crate::{
12    generate_isq, generate_isq_imatrix,
13    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
14    utils::{serialize_tensor, UQFF_VERSION},
15    AfqBits, AfqGroupSize, AfqLayer, DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
16    HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
17    QuantizedConfig, QuantizedSerde, QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
18};
19
20/// Per-tensor FP8 Linear layer with static activation scaling.
21///
22/// This is used for models that have per-tensor FP8 quantization (weight_block_size = null)
23/// with static activation scales. Each linear layer has:
24/// - `<layer>.weight` (FP8 E4M3)
25/// - `<layer>.weight_scale_inv` (F32 scalar) - dequantization scale for weights
26/// - `<layer>.activation_scale` (F32 scalar) - quantization scale for activations
27#[derive(Debug)]
28pub struct PerTensorFP8Linear {
29    weight: Tensor,
30    #[allow(dead_code)]
31    weight_scale_inv: Tensor,
32    #[allow(dead_code)]
33    activation_scale: Option<Tensor>,
34    bias: Option<Tensor>,
35    #[allow(dead_code)]
36    dequant_dtype: DType,
37}
38
39impl QuantMethod for PerTensorFP8Linear {
40    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
41    where
42        Self: Sized,
43    {
44        match method {
45            QuantMethodConfig::PerTensorFP8 {
46                weight,
47                weight_scale_inv,
48                activation_scale,
49                bias,
50                dequant_dtype,
51            } => {
52                // Dequantize immediately since Candle FP8 is storage-only (no ops)
53                let dequant_weight =
54                    ops::fp8_pertensor_dequantize(&weight, &weight_scale_inv, dequant_dtype)?;
55                Ok(Self {
56                    weight: dequant_weight,
57                    weight_scale_inv,
58                    activation_scale,
59                    bias,
60                    dequant_dtype,
61                })
62            }
63            _ => unreachable!(),
64        }
65    }
66
67    fn dequantize_w(&self) -> Result<Tensor> {
68        // Weight is already dequantized on load
69        Ok(self.weight.clone())
70    }
71
72    fn forward(&self, x: &Tensor) -> Result<Tensor> {
73        // Weight is already dequantized, use standard matmul
74        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
75            self.weight.clone(),
76            self.bias.clone(),
77        )))?;
78        unquant.forward(x)
79    }
80
81    fn quantized_act_type(&self) -> Option<DType> {
82        None
83    }
84
85    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
86        candle_core::bail!("PerTensorFP8Linear does not support add_delta_w")
87    }
88
89    fn dtype_and_device(&self) -> (DType, Device) {
90        (DType::F8E4M3, self.weight.device().clone())
91    }
92
93    fn apply_isq(
94        self: Arc<Self>,
95        dtype: Option<IsqType>,
96        device: Device,
97        n_quantized: &AtomicUsize,
98        imatrix_weight: Option<Vec<f32>>,
99        guard: QuantizeOntoGuard,
100    ) -> Result<Arc<dyn QuantMethod>> {
101        let weight = self.dequantize_w()?;
102        match dtype {
103            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
104                let _acquired_quantize_guard = guard.acquire(&device);
105                if imatrix_weight.is_some() {
106                    candle_core::bail!("HQQ does not support imatrix.");
107                }
108
109                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
110                let bits = match dtype.unwrap() {
111                    IsqType::HQQ8 => HqqBits::Eight,
112                    IsqType::HQQ4 => HqqBits::Four,
113                    _ => unreachable!(),
114                };
115                let cfg = HqqConfig {
116                    bits,
117                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
118                    axis: HqqAxis::Zero,
119                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
120                    round_zeros: false,
121                    channel_wise: true,
122                };
123                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
124                if let Some(bias) = &self.bias {
125                    let bias = bias
126                        .to_device(&device)?
127                        .to_dtype(res.dtype_and_device().0)?;
128                    Ok(Arc::new(res.with_bias(bias)))
129                } else {
130                    Ok(Arc::new(res))
131                }
132            }
133            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
134                let _acquired_quantize_guard = guard.acquire(&device);
135                if imatrix_weight.is_some() {
136                    candle_core::bail!("AFQ does not support imatrix.");
137                }
138
139                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
140                let bits = match dtype.unwrap() {
141                    IsqType::AFQ8 => AfqBits::Eight,
142                    IsqType::AFQ6 => AfqBits::Six,
143                    IsqType::AFQ4 => AfqBits::Four,
144                    IsqType::AFQ3 => AfqBits::Three,
145                    IsqType::AFQ2 => AfqBits::Two,
146                    _ => unreachable!(),
147                };
148
149                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
150                    weight: weight.to_device(&device)?,
151                    bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
152                    bits,
153                    group_size: AfqGroupSize::default(),
154                })?))
155            }
156            Some(
157                IsqType::Q2K
158                | IsqType::Q3K
159                | IsqType::Q4K
160                | IsqType::Q4_0
161                | IsqType::Q4_1
162                | IsqType::Q5K
163                | IsqType::Q5_0
164                | IsqType::Q5_1
165                | IsqType::Q6K
166                | IsqType::Q8K
167                | IsqType::Q8_0
168                | IsqType::Q8_1,
169            ) => {
170                let dtype: GgmlDType = dtype.unwrap().try_into()?;
171                let res = if let Some(imatrix_weight) = imatrix_weight {
172                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
173                } else {
174                    generate_isq!(weight, device, dtype, n_quantized, guard)
175                };
176                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
177                    q_weight: res,
178                    b: self
179                        .bias
180                        .as_ref()
181                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
182                })?))
183            }
184            Some(IsqType::F8E4M3) => {
185                let _acquired_quantize_guard = guard.acquire(&device);
186                if imatrix_weight.is_some() {
187                    candle_core::bail!("F8E4M3 does not support imatrix.");
188                }
189
190                let w = weight.to_device(&device)?;
191                let b = if let Some(b) = &self.bias {
192                    Some(b.to_device(&device)?)
193                } else {
194                    None
195                };
196                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
197                    lin: Linear::new(w, b),
198                    dtype: DType::F8E4M3,
199                })?))
200            }
201            None => {
202                let _acquired_quantize_guard = guard.acquire(&device);
203
204                let w = weight.to_device(&device)?;
205                let b = if let Some(b) = &self.bias {
206                    Some(b.to_device(&device)?)
207                } else {
208                    None
209                };
210                Ok(Arc::new(UnquantLinear::new(
211                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
212                )?))
213            }
214        }
215    }
216}
217
218// Serialization structure (same as UnquantLinear):
219//
220// -----------------------
221// UQFF version, u32, little endian
222// -----------------------
223// ISQ type (1 for unquantized), u8, little endian
224// -----------------------
225// Whether bias data is included, u8 boolean
226// -----------------------
227// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
228// -----------------------
229// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
230// -----------------------
231
232impl QuantizedSerde for PerTensorFP8Linear {
233    fn isq_serde_supported(&self) -> bool {
234        true
235    }
236    fn name(&self) -> &'static str {
237        "pertensor-fp8-linear"
238    }
239    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
240        self.serialize_with_bias(self.bias.clone())
241    }
242    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
243        // Serialize as unquantized since weights are already dequantized
244        let mut buffer = Vec::new();
245
246        // Version is always first!
247        buffer.extend(&UQFF_VERSION.to_le_bytes());
248
249        // ISQ type for unquant is 1 (same as UnquantLinear)
250        buffer.push(QuantizedSerdeType::Unquant as u8);
251
252        // Has bias
253        buffer.push(bias.is_some() as u8);
254
255        // Weight (already dequantized)
256        serialize_tensor(&mut buffer, &self.weight)?;
257
258        if let Some(bias) = &bias {
259            // Bias
260            serialize_tensor(&mut buffer, bias)?;
261        }
262
263        Ok(Cow::from(buffer))
264    }
265}
266
267/// Load a per-tensor FP8 linear layer from the VarBuilder.
268///
269/// This handles models with per-tensor FP8 quantization where:
270/// - `weight_block_size` is null (per-tensor, not blockwise)
271/// - Each layer has: weight (FP8), weight_scale_inv (F32), activation_scale (F32)
272pub fn pertensor_fp8_linear_b(
273    in_dim: usize,
274    out_dim: usize,
275    _config: &QuantizedConfig,
276    bias: bool,
277    _hints: Shard,
278    vb: ShardedVarBuilder,
279) -> Result<Arc<dyn QuantMethod>> {
280    // Handle the case where we actually have unquantized weights
281    if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
282        return crate::linear_b(in_dim, out_dim, bias, &None, vb);
283    }
284
285    // Handle the case where the layer is dummy (no tensors)
286    if !vb.contains_tensor("weight") {
287        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
288        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
289    }
290
291    // Load FP8 weight tensor
292    let weight = vb.get_with_hints_dtype(
293        (out_dim, in_dim),
294        "weight",
295        Default::default(),
296        DType::F8E4M3,
297    )?;
298
299    // Load per-tensor weight scale (scalar)
300    let weight_scale_inv =
301        vb.get_with_hints_dtype((), "weight_scale_inv", Default::default(), DType::F32)?;
302
303    // Load activation scale if present (optional - some models may not have it)
304    let activation_scale = if vb.contains_tensor("activation_scale") {
305        Some(vb.get_with_hints_dtype((), "activation_scale", Default::default(), DType::F32)?)
306    } else {
307        None
308    };
309
310    let bias = if bias && vb.contains_tensor("bias") {
311        Some(vb.get((out_dim,), "bias")?)
312    } else {
313        None
314    };
315
316    // Determine the output dtype for dequantization.
317    // We can't use vb.dtype() as that returns F8E4M3 (the storage type).
318    // Use the bias dtype if available, otherwise default to BF16.
319    let dequant_dtype = bias.as_ref().map(|b| b.dtype()).unwrap_or(DType::BF16);
320
321    // Use new() which handles dequantization (Candle FP8 is storage-only)
322    Ok(Arc::new(PerTensorFP8Linear::new(
323        QuantMethodConfig::PerTensorFP8 {
324            weight,
325            weight_scale_inv,
326            activation_scale,
327            bias,
328            dequant_dtype,
329        },
330    )?))
331}