mistralrs_quant/afq/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        UQFF_VERSION,
    },
    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
};

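/// AFQ quantize, dequantize, and quantized-matmul operations used below.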
mod ops;

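/// Bit widths supported for AFQ-quantized weights.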
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
}

impl TryFrom<usize> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            2 => Ok(Self::Two),
            3 => Ok(Self::Three),
            4 => Ok(Self::Four),
            6 => Ok(Self::Six),
            8 => Ok(Self::Eight),
            x => candle_core::bail!("Invalid AFQ bits {x}."),
        }
    }
}

impl TryFrom<u8> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

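/// AFQ quantization group sizes; `Med` (64) is the default.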
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    Low = 32,
    #[default]
    Med = 64,
    High = 128,
}

impl TryFrom<usize> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            32 => Ok(Self::Low),
            64 => Ok(Self::Med),
            128 => Ok(Self::High),
            x => candle_core::bail!("Invalid AFQ group size {x}."),
        }
    }
}

impl TryFrom<u8> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

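/// A linear layer whose weight is stored in AFQ form: packed quantized values
/// (`w_q`) plus per-group `scales` and `biases`, with an optional full-precision
/// output `bias`.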
#[derive(Debug)]
pub struct AfqLayer {
    w_q: Tensor,
    scales: Tensor,
    biases: Tensor,
    bias: Option<Tensor>,
    bits: AfqBits,
    group_size: AfqGroupSize,
}

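// `QuantMethod` implementation: `new` quantizes a full-precision weight,
// `forward` runs the quantized matmul, and `dequantize_w` reconstructs the
// dense weight from the packed representation.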
impl QuantMethod for AfqLayer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Unquantized(_) => unreachable!(),
            QuantMethodConfig::Afq {
                weight,
                bias,
                bits,
                group_size,
            } => {
                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;

                Ok(Self {
                    w_q,
                    scales,
                    biases,
                    bias,
                    bits,
                    group_size,
                })
            }
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
        )
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
            true,
        )
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize_w()?;
        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
            weight: (dequant + delta)?,
            bias: self.bias.clone(),
            bits: self.bits,
            group_size: self.group_size,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        // Re-quantizing an already AFQ-quantized layer via ISQ is not implemented.
        todo!()
    }
}

impl AfqLayer {
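    /// Read only the header of a serialized UQFF AFQ layer to recover its
    /// `IsqType` (AFQ2/3/4/6/8) without materializing the tensors.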
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Weight, scales, biases
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        if has_bias {
            fake_deserialize_tensor(&mut buffer)?
        }

        match bits {
            AfqBits::Two => Ok(IsqType::AFQ2),
            AfqBits::Three => Ok(IsqType::AFQ3),
            AfqBits::Four => Ok(IsqType::AFQ4),
            AfqBits::Six => Ok(IsqType::AFQ6),
            AfqBits::Eight => Ok(IsqType::AFQ8),
        }
    }

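    /// Build an AFQ linear layer from a pre-quantized checkpoint: the `weight`
    /// (packed u32), `scales`, `biases`, and optionally `bias` tensors are read
    /// from the var builder.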
    pub fn afq_linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        // Packed quantized weight: each row of `in_dim` values occupies
        // `in_dim * bits / 32` u32 words.
        let w_q = vb.get_with_hints_dtype(
            (out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales =
            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
        let biases =
            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }
}

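// UQFF (de)serialization support for AFQ layers.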
impl QuantizedSerde for AfqLayer {
    fn name(&self) -> &'static str {
        "afq-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        true
    }
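    /// Serialize in UQFF layout: version, ISQ type tag, has-bias flag, then the
    /// packed weight, scales, and biases tensors, the bits and group-size bytes,
    /// and finally the bias tensor if present.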
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for afq is 4
        buffer.push(QuantizedSerdeType::Afq as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight, scales, biases
        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.biases)?;

        // Bits and group size
        buffer.push(self.bits as u8);
        buffer.push(self.group_size as u8);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
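    /// Deserialize a buffer produced by `serialize_with_bias`, loading all
    /// tensors onto `device`.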
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias: b,
            biases,
            bits,
            group_size,
        }))
    }
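    /// Like `deserialize`, but the bias tensor is returned separately instead of
    /// being stored on the layer.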
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        // Weight, scales, biases
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        // Bits and group size
        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                scales,
                bias: None,
                biases,
                bits,
                group_size,
            }),
            b,
        ))
    }
}