mistralrs_quant/afq/mod.rs
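//! AFQ quantized linear layers: packed quantized weights stored as `u32` words
//! alongside per-group `scales` and `biases` tensors. The quantize, dequantize,
//! and matmul kernels live in the `ops` submodule, and the layer supports UQFF
//! (de)serialization via `QuantizedSerde`.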

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        UQFF_VERSION,
    },
    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
};

mod ops;

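/// Supported AFQ bit widths; the discriminant is the number of bits per weight element.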
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
}

impl TryFrom<usize> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            2 => Ok(Self::Two),
            3 => Ok(Self::Three),
            4 => Ok(Self::Four),
            6 => Ok(Self::Six),
            8 => Ok(Self::Eight),
            x => candle_core::bail!("Invalid AFQ bits {x}."),
        }
    }
}

impl TryFrom<u8> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

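/// Number of weight elements that share one scale/bias pair (defaults to 64).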
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    Low = 32,
    #[default]
    Med = 64,
    High = 128,
}

impl TryFrom<usize> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            32 => Ok(Self::Low),
            64 => Ok(Self::Med),
            128 => Ok(Self::High),
            x => candle_core::bail!("Invalid AFQ group size {x}."),
        }
    }
}

impl TryFrom<u8> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

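/// An AFQ-quantized linear layer: packed quantized weights `w_q` plus per-group
/// `scales` and `biases`, and an optional full-precision output `bias`.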
#[derive(Debug)]
pub struct AfqLayer {
    w_q: Tensor,
    scales: Tensor,
    biases: Tensor,
    bias: Option<Tensor>,
    bits: AfqBits,
    group_size: AfqGroupSize,
}

impl QuantMethod for AfqLayer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Unquantized(_) => unreachable!(),
            QuantMethodConfig::Afq {
                weight,
                bias,
                bits,
                group_size,
            } => {
                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;

                Ok(Self {
                    w_q,
                    scales,
                    biases,
                    bias,
                    bits,
                    group_size,
                })
            }
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
        )
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            None,
            self.group_size,
            self.bits,
            true,
        )
    }

    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            Some(indices),
            self.group_size,
            self.bits,
            true,
        )
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize_w()?;
        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
            weight: (dequant + delta)?,
            bias: self.bias.clone(),
            bits: self.bits,
            group_size: self.group_size,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl AfqLayer {
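    /// Peek at a serialized UQFF AFQ layer and map its bit width to the
    /// corresponding `IsqType`, skipping over the serialized tensor data
    /// rather than loading it.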
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Weight, scales, biases
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        if has_bias {
            fake_deserialize_tensor(&mut buffer)?
        }

        match bits {
            AfqBits::Two => Ok(IsqType::AFQ2),
            AfqBits::Three => Ok(IsqType::AFQ3),
            AfqBits::Four => Ok(IsqType::AFQ4),
            AfqBits::Six => Ok(IsqType::AFQ6),
            AfqBits::Eight => Ok(IsqType::AFQ8),
        }
    }

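    /// Load a pre-quantized AFQ linear layer from a checkpoint. The `weight`
    /// tensor is packed into `u32` words (`in_dim * bits / 32` per output row),
    /// with `scales`/`biases` of shape `(out_dim, in_dim / group_size)`.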
    pub fn afq_linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales =
            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
        let biases =
            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }

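    /// Same as [`Self::afq_linear_b`], but for expert-packed weights: every
    /// tensor gains a leading `num_local_experts` dimension.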
    pub fn afq_packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "scales",
            Default::default(),
        )?;
        let biases = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "biases",
            Default::default(),
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }
}

impl QuantizedSerde for AfqLayer {
    fn name(&self) -> &'static str {
        "afq-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
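    // UQFF layout for AFQ: version (u32, LE), serde type tag (u8), has-bias flag (u8),
    // then the w_q/scales/biases tensors, the bits and group size (u8 each), and
    // finally the bias tensor if present.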
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for afq is 4
        buffer.push(QuantizedSerdeType::Afq as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight, scales, biases
        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.biases)?;

        // Bits and group size
        buffer.push(self.bits as u8);
        buffer.push(self.group_size as u8);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
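    // Deserialize a UQFF buffer written by `serialize_with_bias`, loading the
    // tensors onto `device` while holding the quantize-onto guard.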
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias: b,
            biases,
            bits,
            group_size,
        }))
    }
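    // Same as `deserialize`, but returns the bias separately instead of storing it
    // on the layer (the layer's `bias` field is left as `None`).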
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                scales,
                bias: None,
                biases,
                bits,
                group_size,
            }),
            b,
        ))
    }
}