use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        UQFF_VERSION,
    },
    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
};

pub(crate) mod ops;

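/// Bit widths supported by AFQ quantization. The discriminant doubles as the
/// serialized tag, so `Mxfp4` uses the sentinel value 40 rather than a real
/// bit width to keep it distinct from the integer formats.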
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
    Mxfp4 = 40,
}

impl TryFrom<usize> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            2 => Ok(Self::Two),
            3 => Ok(Self::Three),
            4 => Ok(Self::Four),
            6 => Ok(Self::Six),
            8 => Ok(Self::Eight),
            40 => Ok(Self::Mxfp4),
            x => candle_core::bail!("Invalid AFQ bits {x}."),
        }
    }
}

impl TryFrom<u8> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

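/// Quantization group size: the number of consecutive weights that share one
/// scale/bias pair. Smaller groups track the weights more accurately but
/// store more metadata.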
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    Low = 32,
    #[default]
    Med = 64,
    High = 128,
}

impl TryFrom<usize> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            32 => Ok(Self::Low),
            64 => Ok(Self::Med),
            128 => Ok(Self::High),
            x => candle_core::bail!("Invalid AFQ group size {x}."),
        }
    }
}

impl TryFrom<u8> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

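/// A linear layer whose weight is stored in AFQ (affine) quantized form:
/// packed `u32` words plus per-group scales and biases, with an optional
/// full-precision output bias.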
#[derive(Debug)]
pub struct AfqLayer {
    /// Packed quantized weights, stored as `u32` words.
    w_q: Tensor,
    /// Per-group dequantization scales.
    scales: Tensor,
    /// Per-group dequantization offsets; distinct from the layer's output `bias`.
    biases: Tensor,
    /// Optional additive output bias, kept in full precision.
    bias: Option<Tensor>,
    bits: AfqBits,
    group_size: AfqGroupSize,
}

impl QuantMethod for AfqLayer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            // `AfqLayer` is only ever constructed from an `Afq` config; the
            // other quantization configs are handled by their own layers.
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Afq {
                weight,
                bias,
                bits,
                group_size,
            } => {
                // Quantize the full-precision weight on the fly.
                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;

                Ok(Self {
                    w_q,
                    scales,
                    biases,
                    bias,
                    bits,
                    group_size,
                })
            }
        }
    }

    /// Reconstruct an approximation of the original full-precision weight.
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
        )
    }

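    /// Matrix multiply performed directly on the packed quantized weight.
    /// Note that the stored output `bias` is not applied by this method.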
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            None,
            self.group_size,
            self.bits,
            true,
        )
    }

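    /// Indexed variant of [`Self::forward`]: `indices` is forwarded to the
    /// quantized matmul to gather within the packed weight (e.g. per-token
    /// expert selection when the layer holds stacked experts).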
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            Some(indices),
            self.group_size,
            self.bits,
            true,
        )
    }

    fn quantized_act_type(&self) -> Option<DType> {
        // No activation dtype cast is required before the quantized matmul.
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        // Dequantize, apply the delta, then requantize with the same settings.
        let dequant = self.dequantize_w()?;
        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
            weight: (dequant + delta)?,
            bias: self.bias.clone(),
            bits: self.bits,
            group_size: self.group_size,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        // In-situ requantization of an already-AFQ layer is not implemented.
        todo!()
    }
}

impl AfqLayer {
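    /// Read just enough of a serialized UQFF AFQ layer to report its
    /// [`IsqType`], skipping over the tensor payloads without allocating them.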
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data.to_vec());

        // The UQFF version comes first.
        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Skip the w_q, scales, and biases tensors without materializing them.
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        if has_bias {
            fake_deserialize_tensor(&mut buffer)?;
        }

        match bits {
            AfqBits::Two => Ok(IsqType::AFQ2),
            AfqBits::Three => Ok(IsqType::AFQ3),
            AfqBits::Four => Ok(IsqType::AFQ4),
            AfqBits::Six => Ok(IsqType::AFQ6),
            AfqBits::Eight => Ok(IsqType::AFQ8),
            AfqBits::Mxfp4 => candle_core::bail!("mxfp4 is not supported as an ISQ type"),
        }
    }

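    /// Load an AFQ-quantized linear layer from pre-quantized checkpoint
    /// tensors. With `bits` bits per weight packed into `u32` words, the
    /// packed weight has shape `(out_dim, in_dim * bits / 32)`, and the
    /// scales/biases carry one entry per `group_size` inputs.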
    pub fn afq_linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales =
            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
        let biases =
            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }

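    /// Like [`Self::afq_linear_b`], but loads a stack of `num_local_experts`
    /// expert weights packed into a single set of tensors (e.g. for MoE
    /// layers driven through `gather_forward`).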
    pub fn afq_packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "scales",
            Default::default(),
        )?;
        let biases = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "biases",
            Default::default(),
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }
}

impl QuantizedSerde for AfqLayer {
    fn name(&self) -> &'static str {
        "afq-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
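    // UQFF byte layout for an AFQ layer, as written below and read back by
    // `deserialize`/`deserialize_ext_bias`:
    //   u32 (LE)  UQFF version
    //   u8        quantization tag (`QuantizedSerdeType::Afq`)
    //   u8        has-bias flag
    //   tensor    w_q
    //   tensor    scales
    //   tensor    biases
    //   u8        bits
    //   u8        group size
    //   tensor    bias (present only when the flag is set)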
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // The version is always written first.
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // Tag this buffer as an AFQ layer.
        buffer.push(QuantizedSerdeType::Afq as u8);

        // Has bias?
        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.biases)?;

        buffer.push(self.bits as u8);
        buffer.push(self.group_size as u8);

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
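    /// Reconstruct an [`AfqLayer`] from its UQFF serialization, placing all
    /// tensors on `device`.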
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        // Validate the UQFF header before reading any tensors.
        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Hold the load guard while tensors are placed on the device.
        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias: b,
            biases,
            bits,
            group_size,
        }))
    }
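    /// Like [`Self::deserialize`], but hands the deserialized bias back to the
    /// caller instead of storing it on the layer (`bias` is left as `None`).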
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                scales,
                bias: None,
                biases,
                bits,
                group_size,
            }),
            b,
        ))
    }
}