1use std::{
2 borrow::Cow,
3 io::Cursor,
4 sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use candle_core::{DType, Device, Result, Tensor};
9
10use crate::{
11 utils::{
12 deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
13 UQFF_VERSION,
14 },
15 Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
16 QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
17};
18
19mod ops;
20
/// Supported AFQ quantization bit widths.
///
/// Each variant's discriminant equals the number of bits per weight, so
/// `bits as u8` and the `TryFrom` impls below round-trip losslessly during
/// UQFF (de)serialization.
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
}
30
31impl TryFrom<usize> for AfqBits {
32 type Error = candle_core::Error;
33 fn try_from(value: usize) -> Result<Self> {
34 match value {
35 2 => Ok(Self::Two),
36 3 => Ok(Self::Three),
37 4 => Ok(Self::Four),
38 6 => Ok(Self::Six),
39 8 => Ok(Self::Eight),
40 x => candle_core::bail!("Invalid AFQ bits {x}."),
41 }
42 }
43}
44
45impl TryFrom<u8> for AfqBits {
46 type Error = candle_core::Error;
47 fn try_from(value: u8) -> Result<Self> {
48 Self::try_from(value as usize)
49 }
50}
51
/// Supported AFQ quantization group sizes (number of weights that share one
/// scale/bias pair).
///
/// Each variant's discriminant equals the group size itself, so `as u8` and
/// the `TryFrom` impls below round-trip losslessly during UQFF
/// (de)serialization.
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    // 32 weights per group.
    Low = 32,
    // 64 weights per group — the default.
    #[default]
    Med = 64,
    // 128 weights per group.
    High = 128,
}
60
61impl TryFrom<usize> for AfqGroupSize {
62 type Error = candle_core::Error;
63 fn try_from(value: usize) -> Result<Self> {
64 match value {
65 32 => Ok(Self::Low),
66 64 => Ok(Self::Med),
67 128 => Ok(Self::High),
68 x => candle_core::bail!("Invalid AFQ group size {x}."),
69 }
70 }
71}
72
73impl TryFrom<u8> for AfqGroupSize {
74 type Error = candle_core::Error;
75 fn try_from(value: u8) -> Result<Self> {
76 Self::try_from(value as usize)
77 }
78}
79
/// A linear layer whose weights are stored in AFQ-quantized form.
///
/// The packed weights plus the per-group affine parameters (`scales`,
/// `biases`) are everything needed by the quantize/dequantize/matmul kernels
/// in the `ops` module.
#[derive(Debug)]
pub struct AfqLayer {
    // Packed quantized weights, dtype U32 (see `afq_linear_b` for the
    // checkpoint shape).
    w_q: Tensor,
    // Per-group scale factors; shape (out_dim, in_dim / group_size) when
    // loaded via `afq_linear_b`.
    scales: Tensor,
    // Per-group affine offsets of the quantization, same shape as `scales`.
    // NOTE: distinct from the optional layer `bias` below.
    biases: Tensor,
    // Optional additive layer bias (the classic linear-layer bias).
    bias: Option<Tensor>,
    // Bits per quantized weight.
    bits: AfqBits,
    // Number of weights sharing one scale/bias pair.
    group_size: AfqGroupSize,
}
89
90impl QuantMethod for AfqLayer {
91 fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
92 where
93 Self: Sized,
94 {
95 match method {
96 QuantMethodConfig::Gguf { .. }
97 | QuantMethodConfig::Gptq { .. }
98 | QuantMethodConfig::Hqq { .. }
99 | QuantMethodConfig::Dummy
100 | QuantMethodConfig::FP8 { .. }
101 | QuantMethodConfig::Bnb { .. }
102 | QuantMethodConfig::BlockwiseFP8 { .. }
103 | QuantMethodConfig::Unquantized(_) => unreachable!(),
104 QuantMethodConfig::Afq {
105 weight,
106 bias,
107 bits,
108 group_size,
109 } => {
110 let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;
111
112 Ok(Self {
113 w_q,
114 scales,
115 biases,
116 bias,
117 bits,
118 group_size,
119 })
120 }
121 }
122 }
123
124 fn dequantize_w(&self) -> Result<candle_core::Tensor> {
125 ops::afq_dequantize_op(
126 &self.w_q,
127 &self.scales,
128 &self.biases,
129 self.group_size,
130 self.bits,
131 )
132 }
133
134 fn forward(&self, x: &Tensor) -> Result<Tensor> {
135 ops::afq_mm_op(
136 x,
137 &self.w_q,
138 &self.scales,
139 &self.biases,
140 self.group_size,
141 self.bits,
142 true,
143 )
144 }
145
146 fn quantized_act_type(&self) -> Option<DType> {
147 None
148 }
149
150 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
151 let dequant = self.dequantize_w()?;
152 Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
153 weight: (dequant + delta)?,
154 bias: self.bias.clone(),
155 bits: self.bits,
156 group_size: self.group_size,
157 })?))
158 }
159
160 fn dtype_and_device(&self) -> (DType, candle_core::Device) {
161 (self.scales.dtype(), self.scales.device().clone())
162 }
163
164 fn apply_isq(
165 self: Arc<Self>,
166 _dtype: Option<IsqType>,
167 _device: Device,
168 _n_quantized: &AtomicUsize,
169 _imatrix_weight: Option<Vec<f32>>,
170 _guard: QuantizeOntoGuard,
171 ) -> Result<Arc<dyn QuantMethod>> {
172 todo!()
173 }
174}
175
176impl AfqLayer {
177 pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
178 let mut buffer = Cursor::new(data.to_vec());
179
180 let version = buffer.read_u32::<LittleEndian>()?;
181 if let Err(e) = version_is_compatible(version) {
182 return Err(candle_core::Error::wrap(e));
183 }
184
185 let isq_type = buffer.read_u8()? as usize;
186 if isq_type != QuantizedSerdeType::Afq as usize {
187 candle_core::bail!(
188 "ISQ type ({isq_type}) doesn't match expected type {}",
189 QuantizedSerdeType::Afq as usize
190 );
191 }
192
193 let has_bias = buffer.read_u8()? != 0;
194
195 fake_deserialize_tensor(&mut buffer)?;
197 fake_deserialize_tensor(&mut buffer)?;
198 fake_deserialize_tensor(&mut buffer)?;
199
200 let bits: AfqBits = buffer.read_u8()?.try_into()?;
202 let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;
203
204 if has_bias {
205 fake_deserialize_tensor(&mut buffer)?
206 }
207
208 match bits {
209 AfqBits::Two => Ok(IsqType::AFQ2),
210 AfqBits::Three => Ok(IsqType::AFQ3),
211 AfqBits::Four => Ok(IsqType::AFQ4),
212 AfqBits::Six => Ok(IsqType::AFQ6),
213 AfqBits::Eight => Ok(IsqType::AFQ8),
214 }
215 }
216
217 pub fn afq_linear_b(
218 in_dim: usize,
219 out_dim: usize,
220 config: &QuantizedConfig,
221 bias: bool,
222 vb: ShardedVarBuilder,
223 ) -> Result<Arc<dyn QuantMethod>> {
224 let QuantizedConfig::Afq { bits, group_size } = config else {
225 candle_core::bail!("Unexpected quantization config.")
226 };
227
228 let w_q = vb.get_with_hints_dtype(
229 (out_dim, in_dim * bits / 32),
230 "weight",
231 Default::default(),
232 DType::U32,
233 )?;
234 let scales =
235 vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
236 let biases =
237 vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;
238
239 let bias = if bias {
240 Some(vb.get((out_dim,), "bias")?)
241 } else {
242 None
243 };
244
245 Ok(Arc::new(Self {
246 w_q,
247 scales,
248 bias,
249 biases,
250 bits: AfqBits::try_from(*bits)?,
251 group_size: AfqGroupSize::try_from(*group_size)?,
252 }))
253 }
254}
255
256impl QuantizedSerde for AfqLayer {
257 fn name(&self) -> &'static str {
258 "afq-layer"
259 }
260 fn isq_serde_supported(&self) -> bool {
261 true
262 }
263 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
264 let mut buffer = Vec::new();
265
266 buffer.extend(&UQFF_VERSION.to_le_bytes());
268
269 buffer.push(QuantizedSerdeType::Afq as u8);
271
272 buffer.push(bias.is_some() as u8);
274
275 serialize_tensor(&mut buffer, &self.w_q)?;
277 serialize_tensor(&mut buffer, &self.scales)?;
278 serialize_tensor(&mut buffer, &self.biases)?;
279
280 buffer.push(self.bits as u8);
282 buffer.push(self.group_size as u8);
283
284 if let Some(bias) = &bias {
285 serialize_tensor(&mut buffer, bias)?;
287 }
288
289 Ok(Cow::from(buffer))
290 }
291 fn deserialize(
292 data: Cow<[u8]>,
293 device: &Device,
294 _comm: &Arc<Comm>,
295 guard: QuantizeOntoGuard,
296 ) -> Result<Arc<dyn QuantMethod>>
297 where
298 Self: Sized,
299 {
300 let mut buffer = Cursor::new(data.to_vec());
301
302 let version = buffer.read_u32::<LittleEndian>()?;
303 if let Err(e) = version_is_compatible(version) {
304 return Err(candle_core::Error::wrap(e));
305 }
306
307 let isq_type = buffer.read_u8()? as usize;
308 if isq_type != QuantizedSerdeType::Afq as usize {
309 candle_core::bail!(
310 "ISQ type ({isq_type}) doesn't match expected type {}",
311 QuantizedSerdeType::Afq as usize
312 );
313 }
314
315 let has_bias = buffer.read_u8()? != 0;
316
317 let _acquired_load_guard = guard.acquire();
318 let w_q = deserialize_tensor(&mut buffer, device)?;
320 let scales = deserialize_tensor(&mut buffer, device)?;
321 let biases = deserialize_tensor(&mut buffer, device)?;
322
323 let bits: AfqBits = buffer.read_u8()?.try_into()?;
325 let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;
326
327 let b = if has_bias {
328 Some(deserialize_tensor(&mut buffer, device)?)
329 } else {
330 None
331 };
332
333 Ok(Arc::new(Self {
334 w_q,
335 scales,
336 bias: b,
337 biases,
338 bits,
339 group_size,
340 }))
341 }
342 fn deserialize_ext_bias(
343 data: Cow<[u8]>,
344 device: &Device,
345 guard: QuantizeOntoGuard,
346 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
347 where
348 Self: Sized,
349 {
350 let mut buffer = Cursor::new(data.to_vec());
351
352 let version = buffer.read_u32::<LittleEndian>()?;
353 if let Err(e) = version_is_compatible(version) {
354 return Err(candle_core::Error::wrap(e));
355 }
356
357 let isq_type = buffer.read_u8()? as usize;
358 if isq_type != QuantizedSerdeType::Afq as usize {
359 candle_core::bail!(
360 "ISQ type ({isq_type}) doesn't match expected type {}",
361 QuantizedSerdeType::Afq as usize
362 );
363 }
364
365 let has_bias = buffer.read_u8()? != 0;
366
367 let _acquired_load_guard = guard.acquire();
368 let w_q = deserialize_tensor(&mut buffer, device)?;
370 let scales = deserialize_tensor(&mut buffer, device)?;
371 let biases = deserialize_tensor(&mut buffer, device)?;
372
373 let bits: AfqBits = buffer.read_u8()?.try_into()?;
375 let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;
376
377 let b = if has_bias {
378 Some(deserialize_tensor(&mut buffer, device)?)
379 } else {
380 None
381 };
382
383 Ok((
384 Arc::new(Self {
385 w_q,
386 scales,
387 bias: None,
388 biases,
389 bits,
390 group_size,
391 }),
392 b,
393 ))
394 }
395}