mistralrs_quant/afq/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        UQFF_VERSION,
    },
    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
};

pub(crate) mod ops;

#[cfg(feature = "cuda")]
pub(crate) mod ffi;

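/// Bit widths accepted by the AFQ quantization ops. The `Mxfp4` variant
/// (discriminant 40) selects the MXFP4 micro-scaling format rather than a
/// plain integer bit width.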
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
    Mxfp4 = 40,
}

impl TryFrom<usize> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            2 => Ok(Self::Two),
            3 => Ok(Self::Three),
            4 => Ok(Self::Four),
            6 => Ok(Self::Six),
            8 => Ok(Self::Eight),
            40 => Ok(Self::Mxfp4),
            x => candle_core::bail!("Invalid AFQ bits {x}."),
        }
    }
}

impl TryFrom<u8> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

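/// Number of weight elements that share one scale/bias pair; 64 is the default.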
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    Low = 32,
    #[default]
    Med = 64,
    High = 128,
}

impl TryFrom<usize> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            32 => Ok(Self::Low),
            64 => Ok(Self::Med),
            128 => Ok(Self::High),
            x => candle_core::bail!("Invalid AFQ group size {x}."),
        }
    }
}

impl TryFrom<u8> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

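/// A linear layer quantized with AFQ: weights packed into `u32` words plus
/// per-group `scales` and `biases`, with an optional full-precision output bias.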
#[derive(Debug)]
pub struct AfqLayer {
    w_q: Tensor,
    scales: Tensor,
    biases: Tensor,
    bias: Option<Tensor>,
    bits: AfqBits,
    group_size: AfqGroupSize,
}

impl QuantMethod for AfqLayer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Afq {
                weight,
                bias,
                bits,
                group_size,
            } => {
                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;

                Ok(Self {
                    w_q,
                    scales,
                    biases,
                    bias,
                    bits,
                    group_size,
                })
            }
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
        )
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            None,
            self.group_size,
            self.bits,
            true,
        )
    }

    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            Some(indices),
            self.group_size,
            self.bits,
            true,
        )
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize_w()?;
        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
            weight: (dequant + delta)?,
            bias: self.bias.clone(),
            bits: self.bits,
            group_size: self.group_size,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl AfqLayer {
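    /// Peek at a serialized UQFF AFQ layer and report its ISQ type without
    /// materializing any tensors: the tensor payloads are skipped with
    /// `fake_deserialize_tensor` and only the stored bit width is decoded.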
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Weight, scales, biases
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        if has_bias {
            fake_deserialize_tensor(&mut buffer)?
        }

        match bits {
            AfqBits::Two => Ok(IsqType::AFQ2),
            AfqBits::Three => Ok(IsqType::AFQ3),
            AfqBits::Four => Ok(IsqType::AFQ4),
            AfqBits::Six => Ok(IsqType::AFQ6),
            AfqBits::Eight => Ok(IsqType::AFQ8),
            AfqBits::Mxfp4 => candle_core::bail!("mxfp4 is not supported as an ISQ type"),
        }
    }

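    /// Load a pre-quantized AFQ linear layer from a [`ShardedVarBuilder`].
    ///
    /// The quantized `weight` is stored as packed `u32` words with trailing
    /// dimension `in_dim * bits / 32`, while `scales` and `biases` carry one
    /// entry per `group_size` input elements.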
    pub fn afq_linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales =
            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
        let biases =
            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }

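    /// Load a stack of AFQ-quantized expert weights (e.g. for a packed MoE
    /// layer): the same layout as [`Self::afq_linear_b`], with a leading
    /// `num_local_experts` dimension on every tensor.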
    pub fn afq_packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "scales",
            Default::default(),
        )?;
        let biases = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "biases",
            Default::default(),
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }
}

impl QuantizedSerde for AfqLayer {
    fn name(&self) -> &'static str {
        "afq-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
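    /// Serialize to the UQFF AFQ layout: `u32` version, `u8` type tag,
    /// `u8` has-bias flag, then `w_q`/`scales`/`biases`, the `bits` and
    /// `group_size` bytes, and finally the optional bias tensor.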
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for afq is 4
        buffer.push(QuantizedSerdeType::Afq as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight, scales, biases
        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.biases)?;

        // Bits and group size
        buffer.push(self.bits as u8);
        buffer.push(self.group_size as u8);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
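    /// Reconstruct an `AfqLayer` from bytes produced by `serialize_with_bias`,
    /// placing every tensor on `device`; loading is gated by the provided
    /// [`QuantizeOntoGuard`].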
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias: b,
            biases,
            bits,
            group_size,
        }))
    }
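    /// Like `deserialize`, but the stored bias (if any) is returned separately
    /// instead of being attached to the layer.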
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                scales,
                bias: None,
                biases,
                bits,
                group_size,
            }),
            b,
        ))
    }
}
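
// A minimal round-trip sketch, not part of the upstream crate: it assumes a
// backend where the AFQ kernels in `ops` are implemented (Metal is used here
// as a hypothetical choice) and a weight width divisible by the group size.
// It only illustrates how `AfqLayer::new` and `dequantize_w` fit together.
#[cfg(test)]
mod afq_roundtrip_sketch {
    use super::*;

    #[test]
    #[ignore = "requires a device with AFQ kernels"]
    fn quantize_dequantize_roundtrip() -> Result<()> {
        // Hypothetical device choice; swap for whatever backend has AFQ support.
        let device = Device::new_metal(0)?;

        // 64x128 weight: 128 input features = 2 groups with the default group size of 64.
        let weight = Tensor::randn(0f32, 1f32, (64, 128), &device)?;

        let layer = AfqLayer::new(QuantMethodConfig::Afq {
            weight: weight.clone(),
            bias: None,
            bits: AfqBits::Four,
            group_size: AfqGroupSize::default(),
        })?;

        // 4-bit quantization is lossy, so only the shape is checked here.
        let dequant = layer.dequantize_w()?;
        assert_eq!(dequant.dims(), weight.dims());
        Ok(())
    }
}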