mistralrs_quant/afq/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        UQFF_VERSION,
    },
    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
};

pub(crate) mod ops;

#[cfg(feature = "cuda")]
pub(crate) mod ffi;

/// Bit widths supported by AFQ quantization.
///
/// Discriminants match the bit width, except `Mxfp4`, whose value of 40 is a
/// sentinel tag rather than a literal bit count.
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
    Mxfp4 = 40,
}

impl TryFrom<usize> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            2 => Ok(Self::Two),
            3 => Ok(Self::Three),
            4 => Ok(Self::Four),
            6 => Ok(Self::Six),
            8 => Ok(Self::Eight),
            40 => Ok(Self::Mxfp4),
            x => candle_core::bail!("Invalid AFQ bits {x}."),
        }
    }
}

impl TryFrom<u8> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

/// Quantization group size: the number of weight elements that share one
/// scale/bias pair.
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    Low = 32,
    #[default]
    Med = 64,
    High = 128,
}

impl TryFrom<usize> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            32 => Ok(Self::Low),
            64 => Ok(Self::Med),
            128 => Ok(Self::High),
            x => candle_core::bail!("Invalid AFQ group size {x}."),
        }
    }
}

impl TryFrom<u8> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}
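
// A minimal conversion sketch (illustrative, not from the original source),
// assuming a typical configuration of 4-bit weights with 64-element groups:
//
//     let bits = AfqBits::try_from(4usize)?;          // AfqBits::Four
//     let group_size = AfqGroupSize::try_from(64u8)?; // AfqGroupSize::Med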

/// An AFQ-quantized linear layer: bit-packed `u32` weights plus per-group
/// affine parameters (`scales` and `biases`), with an optional output bias.
#[derive(Debug)]
pub struct AfqLayer {
    w_q: Tensor,
    scales: Tensor,
    biases: Tensor,
    bias: Option<Tensor>,
    bits: AfqBits,
    group_size: AfqGroupSize,
}

impl QuantMethod for AfqLayer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Afq {
                weight,
                bias,
                bits,
                group_size,
            } => {
                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;

                Ok(Self {
                    w_q,
                    scales,
                    biases,
                    bias,
                    bits,
                    group_size,
                })
            }
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
        )
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            None,
            self.group_size,
            self.bits,
            true,
        )
    }

    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            Some(indices),
            self.group_size,
            self.bits,
            true,
        )
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize_w()?;
        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
            weight: (dequant + delta)?,
            bias: self.bias.clone(),
            bits: self.bits,
            group_size: self.group_size,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}
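
// A minimal quantize-and-run sketch (illustrative, not part of the original
// source). The shapes, device, and 4-bit/group-64 settings are assumptions, and
// AFQ kernel availability for the chosen device is assumed as well:
//
//     use candle_core::{Device, Tensor};
//
//     let dev = Device::Cpu; // assumption: a device with AFQ kernel support
//     let weight = Tensor::randn(0f32, 1f32, (1024, 4096), &dev)?; // (out_dim, in_dim)
//     let layer = AfqLayer::new(QuantMethodConfig::Afq {
//         weight,
//         bias: None,
//         bits: AfqBits::Four,
//         group_size: AfqGroupSize::Med,
//     })?;
//
//     let x = Tensor::randn(0f32, 1f32, (1, 8, 4096), &dev)?;
//     let y = layer.forward(&x)?;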

impl AfqLayer {
    /// Read the UQFF header of a serialized AFQ tensor and report the
    /// corresponding `IsqType`, skipping over the tensor payloads.
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Weight, scales, biases
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        if has_bias {
            fake_deserialize_tensor(&mut buffer)?
        }

        match bits {
            AfqBits::Two => Ok(IsqType::AFQ2),
            AfqBits::Three => Ok(IsqType::AFQ3),
            AfqBits::Four => Ok(IsqType::AFQ4),
            AfqBits::Six => Ok(IsqType::AFQ6),
            AfqBits::Eight => Ok(IsqType::AFQ8),
            AfqBits::Mxfp4 => candle_core::bail!("mxfp4 is not supported as an ISQ type"),
        }
    }

    /// Load a pre-quantized AFQ linear layer from a `ShardedVarBuilder`.
    ///
    /// Expects a packed `u32` `weight` of shape `(out_dim, in_dim * bits / 32)`
    /// and `scales`/`biases` of shape `(out_dim, in_dim / group_size)`.
    pub fn afq_linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales =
            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
        let biases =
            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }

    /// Load a stack of pre-quantized AFQ expert weights (for example, MoE
    /// experts), with a leading `num_local_experts` dimension on every tensor.
    pub fn afq_packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "scales",
            Default::default(),
        )?;
        let biases = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "biases",
            Default::default(),
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }
}
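
// A minimal loading sketch (illustrative, not from the original source): given a
// checkpoint that already stores AFQ tensors named "weight", "scales", and
// "biases", `afq_linear_b` reconstructs the layer. The dimensions and the `vb`
// (a `ShardedVarBuilder` scoped to the layer's prefix) are assumptions:
//
//     let config = QuantizedConfig::Afq { bits: 4, group_size: 64 };
//     let proj = AfqLayer::afq_linear_b(4096, 1024, &config, /* bias */ false, vb)?;
//     let y = proj.forward(&x)?;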

impl QuantizedSerde for AfqLayer {
    fn name(&self) -> &'static str {
        "afq-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for afq is 4
        buffer.push(QuantizedSerdeType::Afq as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight, scales, biases
        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.biases)?;

        // Bits and group size
        buffer.push(self.bits as u8);
        buffer.push(self.group_size as u8);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
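
    // For reference, the UQFF record written by `serialize_with_bias` above (and
    // consumed by `deserialize`/`deserialize_ext_bias` below) is laid out as:
    //
    //     u32 LE  UQFF_VERSION
    //     u8      QuantizedSerdeType::Afq (4)
    //     u8      has_bias flag (0 or 1)
    //     tensor  w_q (packed u32 weights)
    //     tensor  scales
    //     tensor  biases
    //     u8      bits       (AfqBits discriminant)
    //     u8      group_size (AfqGroupSize discriminant)
    //     tensor  bias       (present only when has_bias == 1)
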
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias: b,
            biases,
            bits,
            group_size,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                scales,
                bias: None,
                biases,
                bits,
                group_size,
            }),
            b,
        ))
    }
}
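
// A round-trip sketch (illustrative, not from the original source): bytes
// produced by `serialize` can be fed back through `deserialize`. The `device`,
// `comm`, and `guard` values are assumed to come from the surrounding loading
// pipeline:
//
//     let bytes = layer.serialize()?;
//     let restored = AfqLayer::deserialize(bytes, &device, &comm, guard)?;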