// mistralrs_quant/afq/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        UQFF_VERSION,
    },
    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
};

mod ops;

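/// Bit widths supported by AFQ quantization. Quantized weights are bit-packed into
/// `u32` words, so a row of `in_dim` values occupies `in_dim * bits / 32` words.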
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
}

impl TryFrom<usize> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            2 => Ok(Self::Two),
            3 => Ok(Self::Three),
            4 => Ok(Self::Four),
            6 => Ok(Self::Six),
            8 => Ok(Self::Eight),
            x => candle_core::bail!("Invalid AFQ bits {x}."),
        }
    }
}

impl TryFrom<u8> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

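/// Number of consecutive weights that share one scale/bias pair (default: 64).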
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    Low = 32,
    #[default]
    Med = 64,
    High = 128,
}

impl TryFrom<usize> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            32 => Ok(Self::Low),
            64 => Ok(Self::Med),
            128 => Ok(Self::High),
            x => candle_core::bail!("Invalid AFQ group size {x}."),
        }
    }
}

impl TryFrom<u8> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

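/// An AFQ (affine) quantized linear layer: each group of `group_size` weights is stored
/// as `bits`-bit integers in `w_q` and reconstructed roughly as `w ≈ w_q * scales + biases`,
/// with `scales`/`biases` holding one entry per group.
///
/// A rough construction sketch (`in_dim`/`out_dim` and the random weight are illustrative,
/// not taken from the crate):
///
/// ```ignore
/// let weight = Tensor::randn(0f32, 1f32, (out_dim, in_dim), &Device::Cpu)?;
/// let layer = AfqLayer::new(QuantMethodConfig::Afq {
///     weight,
///     bias: None,
///     bits: AfqBits::Four,
///     group_size: AfqGroupSize::default(),
/// })?;
/// ```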
#[derive(Debug)]
pub struct AfqLayer {
    w_q: Tensor,
    scales: Tensor,
    biases: Tensor,
    bias: Option<Tensor>,
    bits: AfqBits,
    group_size: AfqGroupSize,
}

impl QuantMethod for AfqLayer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Unquantized(_) => unreachable!(),
            QuantMethodConfig::Afq {
                weight,
                bias,
                bits,
                group_size,
            } => {
                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;

                Ok(Self {
                    w_q,
                    scales,
                    biases,
                    bias,
                    bits,
                    group_size,
                })
            }
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
        )
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            None,
            self.group_size,
            self.bits,
            true,
        )?;
        // Apply the optional per-output-channel bias after the quantized matmul.
        match &self.bias {
            Some(bias) => x.broadcast_add(bias),
            None => Ok(x),
        }
    }

    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            Some(indices),
            self.group_size,
            self.bits,
            true,
        )
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize_w()?;
        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
            weight: (dequant + delta)?,
            bias: self.bias.clone(),
            bits: self.bits,
            group_size: self.group_size,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        // Requantizing an already AFQ-quantized layer via ISQ is not implemented.
        todo!()
    }
}

impl AfqLayer {
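    /// Read just the header of a UQFF-serialized AFQ tensor and report its bit width as an
    /// [`IsqType`], skipping over the tensor payloads without deserializing them.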
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Weight, scales, biases
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        if has_bias {
            fake_deserialize_tensor(&mut buffer)?
        }

        match bits {
            AfqBits::Two => Ok(IsqType::AFQ2),
            AfqBits::Three => Ok(IsqType::AFQ3),
            AfqBits::Four => Ok(IsqType::AFQ4),
            AfqBits::Six => Ok(IsqType::AFQ6),
            AfqBits::Eight => Ok(IsqType::AFQ8),
        }
    }

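    /// Load a pre-quantized AFQ linear layer from a checkpoint. The expected shapes are
    /// `weight: (out_dim, in_dim * bits / 32)` (bit-packed into `u32` words) and
    /// `scales`/`biases: (out_dim, in_dim / group_size)`. For example, with 4-bit
    /// quantization, group size 64, and `in_dim = 4096`, each weight row packs into
    /// 512 `u32` words with 64 scale/bias entries.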
    pub fn afq_linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales =
            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
        let biases =
            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }

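    /// Like [`Self::afq_linear_b`], but for expert weights packed along a leading
    /// `num_local_experts` dimension, suitable for use with `gather_forward`.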
    pub fn afq_packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "scales",
            Default::default(),
        )?;
        let biases = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "biases",
            Default::default(),
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }
}

impl QuantizedSerde for AfqLayer {
    fn name(&self) -> &'static str {
        "afq-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        true
    }
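    // UQFF AFQ layout: version (u32, LE) | serde type (u8) | has_bias (u8) |
    // w_q | scales | biases | bits (u8) | group_size (u8) | bias tensor (if present).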
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for afq is 4
        buffer.push(QuantizedSerdeType::Afq as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight, scales, biases
        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.biases)?;

        // Bits and group size
        buffer.push(self.bits as u8);
        buffer.push(self.group_size as u8);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias: b,
            biases,
            bits,
            group_size,
        }))
    }
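    // Same wire format as `deserialize`, but the bias is handed back to the caller
    // instead of being stored on the layer (the layer's `bias` is left as `None`).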
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                scales,
                bias: None,
                biases,
                bits,
                group_size,
            }),
            b,
        ))
    }
}