mistralrs_quant/afq/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        UQFF_VERSION,
    },
    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
};

pub(crate) mod ops;

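/// Bit widths supported by AFQ quantization. The numeric discriminants are also
/// used as the serialized byte values; `Mxfp4` uses the out-of-band value 40 to
/// distinguish it from the plain 4-bit layout.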
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
    Mxfp4 = 40,
}

impl TryFrom<usize> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            2 => Ok(Self::Two),
            3 => Ok(Self::Three),
            4 => Ok(Self::Four),
            6 => Ok(Self::Six),
            8 => Ok(Self::Eight),
            40 => Ok(Self::Mxfp4),
            x => candle_core::bail!("Invalid AFQ bits {x}."),
        }
    }
}

impl TryFrom<u8> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

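/// Supported AFQ group sizes (number of weight elements per scale/bias group);
/// the default is 64.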
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    Low = 32,
    #[default]
    Med = 64,
    High = 128,
}

impl TryFrom<usize> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            32 => Ok(Self::Low),
            64 => Ok(Self::Med),
            128 => Ok(Self::High),
            x => candle_core::bail!("Invalid AFQ group size {x}."),
        }
    }
}

impl TryFrom<u8> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

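/// A linear layer stored in AFQ-quantized form: `w_q` holds the packed `u32`
/// weight, `scales` and `biases` hold the per-group affine parameters, and
/// `bias` is the optional additive bias of the layer itself.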
#[derive(Debug)]
pub struct AfqLayer {
    w_q: Tensor,
    scales: Tensor,
    biases: Tensor,
    bias: Option<Tensor>,
    bits: AfqBits,
    group_size: AfqGroupSize,
}

impl QuantMethod for AfqLayer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Afq {
                weight,
                bias,
                bits,
                group_size,
            } => {
                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;

                Ok(Self {
                    w_q,
                    scales,
                    biases,
                    bias,
                    bits,
                    group_size,
                })
            }
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
        )
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            None,
            self.group_size,
            self.bits,
            true,
        )
    }

    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            Some(indices),
            self.group_size,
            self.bits,
            true,
        )
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize_w()?;
        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
            weight: (dequant + delta)?,
            bias: self.bias.clone(),
            bits: self.bits,
            group_size: self.group_size,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl AfqLayer {
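    /// Inspect a serialized UQFF AFQ buffer and report the matching `IsqType`,
    /// skipping over the tensor payloads instead of deserializing them.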
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Weight, scales, biases
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        if has_bias {
            fake_deserialize_tensor(&mut buffer)?
        }

        match bits {
            AfqBits::Two => Ok(IsqType::AFQ2),
            AfqBits::Three => Ok(IsqType::AFQ3),
            AfqBits::Four => Ok(IsqType::AFQ4),
            AfqBits::Six => Ok(IsqType::AFQ6),
            AfqBits::Eight => Ok(IsqType::AFQ8),
            AfqBits::Mxfp4 => candle_core::bail!("mxfp4 is not supported as an ISQ type"),
        }
    }

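    /// Load an AFQ-quantized linear layer from `vb`. Expects a packed `u32`
    /// `weight` of shape `(out_dim, in_dim * bits / 32)`, `scales` and `biases`
    /// of shape `(out_dim, in_dim / group_size)`, and optionally a `bias` of
    /// shape `(out_dim,)`.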
    pub fn afq_linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales =
            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
        let biases =
            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }

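    /// Like [`Self::afq_linear_b`], but for expert weights stacked along a
    /// leading `num_local_experts` dimension.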
    pub fn afq_packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "scales",
            Default::default(),
        )?;
        let biases = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "biases",
            Default::default(),
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }
}

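// UQFF serialization layout for AFQ layers (see `serialize_with_bias` and
// `deserialize` below): UQFF version (u32, little endian), serde type tag
// (u8, `QuantizedSerdeType::Afq`), has-bias flag (u8), then the `w_q`,
// `scales`, and `biases` tensors, the bits byte, the group-size byte, and
// finally the optional bias tensor.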
impl QuantizedSerde for AfqLayer {
    fn name(&self) -> &'static str {
        "afq-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for afq is 4
        buffer.push(QuantizedSerdeType::Afq as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight, scales, biases
        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.biases)?;

        // Bits and group size
        buffer.push(self.bits as u8);
        buffer.push(self.group_size as u8);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias: b,
            biases,
            bits,
            group_size,
        }))
    }
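    // Variant of `deserialize` that hands the bias back to the caller instead
    // of storing it on the layer (`bias` is left as `None`).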
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                scales,
                bias: None,
                biases,
                bits,
                group_size,
            }),
            b,
        ))
    }
}