use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        UQFF_VERSION,
    },
    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
};

pub(crate) mod ops;

#[cfg(feature = "cuda")]
pub(crate) mod ffi;

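/// Bit widths supported by AFQ (affine quantization). Discriminants are the
/// bit widths themselves, with 40 serving as a sentinel for MXFP4.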
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
    Mxfp4 = 40,
}

impl TryFrom<usize> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            2 => Ok(Self::Two),
            3 => Ok(Self::Three),
            4 => Ok(Self::Four),
            6 => Ok(Self::Six),
            8 => Ok(Self::Eight),
            40 => Ok(Self::Mxfp4),
            x => candle_core::bail!("Invalid AFQ bits {x}."),
        }
    }
}

impl TryFrom<u8> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

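/// Number of consecutive weights that share one scale/bias pair.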
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    Low = 32,
    #[default]
    Med = 64,
    High = 128,
}

impl TryFrom<usize> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            32 => Ok(Self::Low),
            64 => Ok(Self::Med),
            128 => Ok(Self::High),
            x => candle_core::bail!("Invalid AFQ group size {x}."),
        }
    }
}

impl TryFrom<u8> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

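/// A linear layer stored in AFQ (affine quantization) form: quantized values
/// packed into `u32` words, plus per-group `scales` and `biases` so a weight
/// dequantizes as roughly `w ≈ q * scale + bias`.
///
/// A minimal usage sketch (ignored: AFQ kernels require a supported
/// accelerator backend; `weight` and `x` are assumed float tensors):
///
/// ```ignore
/// let layer = AfqLayer::new(QuantMethodConfig::Afq {
///     weight,
///     bias: None,
///     bits: AfqBits::Four,
///     group_size: AfqGroupSize::Med,
/// })?;
/// let y = layer.forward(&x)?;
/// ```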
#[derive(Debug)]
pub struct AfqLayer {
    /// Packed quantized weights, dtype `U32`.
    w_q: Tensor,
    /// Per-group dequantization scales.
    scales: Tensor,
    /// Per-group dequantization offsets (not the layer bias).
    biases: Tensor,
    /// Optional additive layer bias, as in an ordinary linear layer.
    bias: Option<Tensor>,
    bits: AfqBits,
    group_size: AfqGroupSize,
}

impl QuantMethod for AfqLayer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Afq {
                weight,
                bias,
                bits,
                group_size,
            } => {
                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;

                Ok(Self {
                    w_q,
                    scales,
                    biases,
                    bias,
                    bits,
                    group_size,
                })
            }
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
        )
    }

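    /// Quantized matmul: multiplies `x` against the packed weight directly via
    /// `ops::afq_mm_op`, without materializing a dequantized weight. Note that
    /// `self.bias` is not applied here.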
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            None,
            self.group_size,
            self.bits,
            true,
        )
    }

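    /// Like `forward`, but gathers along the leading dimension of the packed
    /// weight by `indices` first, e.g. for mixture-of-experts layers loaded
    /// via [`Self::afq_packed_linear_b`].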
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            Some(indices),
            self.group_size,
            self.bits,
            true,
        )
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

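    /// Applies a weight delta by dequantizing, adding `delta`, and
    /// requantizing at the same bit width and group size.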
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize_w()?;
        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
            weight: (dequant + delta)?,
            bias: self.bias.clone(),
            bits: self.bits,
            group_size: self.group_size,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl AfqLayer {
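    /// Peeks at a serialized UQFF buffer and reports the contained AFQ ISQ
    /// type, skipping over the tensor payloads instead of materializing them.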
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        if has_bias {
            fake_deserialize_tensor(&mut buffer)?
        }

        match bits {
            AfqBits::Two => Ok(IsqType::AFQ2),
            AfqBits::Three => Ok(IsqType::AFQ3),
            AfqBits::Four => Ok(IsqType::AFQ4),
            AfqBits::Six => Ok(IsqType::AFQ6),
            AfqBits::Eight => Ok(IsqType::AFQ8),
            AfqBits::Mxfp4 => candle_core::bail!("mxfp4 is not supported as an ISQ type"),
        }
    }

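    /// Loads a pre-quantized AFQ linear layer from a checkpoint: a packed
    /// `u32` weight of shape `(out_dim, in_dim * bits / 32)` plus `scales` and
    /// `biases` of shape `(out_dim, in_dim / group_size)`.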
    pub fn afq_linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales =
            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
        let biases =
            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }

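    /// Like [`Self::afq_linear_b`], but for stacked expert weights: each
    /// tensor carries a leading `num_local_experts` dimension.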
    pub fn afq_packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "scales",
            Default::default(),
        )?;
        let biases = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "biases",
            Default::default(),
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }
}

impl QuantizedSerde for AfqLayer {
    fn name(&self) -> &'static str {
        "afq-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
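    /// Serialized UQFF layout: version (`u32`, LE) | serde type (`u8`) |
    /// has-bias flag (`u8`) | `w_q` | `scales` | `biases` | bits (`u8`) |
    /// group size (`u8`) | optional bias tensor.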
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        buffer.push(QuantizedSerdeType::Afq as u8);

        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.biases)?;

        buffer.push(self.bits as u8);
        buffer.push(self.group_size as u8);

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
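    /// Inverse of [`Self::serialize_with_bias`]: reads the same field order
    /// and reconstructs the layer on `device`.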
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias: b,
            biases,
            bits,
            group_size,
        }))
    }
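    /// Like [`Self::deserialize`], but hands the bias back to the caller
    /// instead of storing it on the layer.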
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                scales,
                bias: None,
                biases,
                bits,
                group_size,
            }),
            b,
        ))
    }
}