use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        UQFF_VERSION,
    },
    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
};

pub(crate) mod ops;

#[cfg(feature = "cuda")]
pub(crate) mod ffi;

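/// Supported AFQ (affine quantization) bit widths.
///
/// The discriminant is the number of bits per quantized element, except for
/// `Mxfp4`, whose out-of-band value 40 is a sentinel for the MXFP4 format
/// rather than a literal bit width.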
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
    Mxfp4 = 40,
}

impl TryFrom<usize> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            2 => Ok(Self::Two),
            3 => Ok(Self::Three),
            4 => Ok(Self::Four),
            6 => Ok(Self::Six),
            8 => Ok(Self::Eight),
            40 => Ok(Self::Mxfp4),
            x => candle_core::bail!("Invalid AFQ bits {x}."),
        }
    }
}

impl TryFrom<u8> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

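/// Number of weight elements that share one quantization scale/bias pair.
/// Smaller groups follow the weight distribution more closely at the cost of
/// storing more scales and biases; 64 is the default.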
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    Low = 32,
    #[default]
    Med = 64,
    High = 128,
}

impl TryFrom<usize> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            32 => Ok(Self::Low),
            64 => Ok(Self::Med),
            128 => Ok(Self::High),
            x => candle_core::bail!("Invalid AFQ group size {x}."),
        }
    }
}

impl TryFrom<u8> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

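/// A linear layer whose weight is stored in AFQ (affine) quantized form:
/// `w_q` packs the quantized elements into `u32` words, while `scales` and
/// `biases` hold one affine pair per group of `group_size` input elements.
///
/// Conceptually (a sketch of the affine scheme, ignoring the packed layout),
/// each weight element is recovered as
///
/// ```text
/// w[i, j] ≈ scales[i, j / group_size] * q[i, j] + biases[i, j / group_size]
/// ```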
#[derive(Debug)]
pub struct AfqLayer {
    w_q: Tensor,
    scales: Tensor,
    biases: Tensor,
    bias: Option<Tensor>,
    bits: AfqBits,
    group_size: AfqGroupSize,
}

impl QuantMethod for AfqLayer {
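    /// Quantize an unquantized weight into an AFQ layer.
    ///
    /// Only [`QuantMethodConfig::Afq`] is accepted here; every other config
    /// variant is routed to a different `QuantMethod` implementation, so
    /// reaching one of them is a bug (hence `unreachable!`).
    ///
    /// A minimal usage sketch (assumes the weight's input dimension is
    /// divisible by the group size and an AFQ-capable backend is enabled):
    ///
    /// ```ignore
    /// let weight = Tensor::randn(0f32, 1f32, (256, 256), &device)?;
    /// let layer = AfqLayer::new(QuantMethodConfig::Afq {
    ///     weight,
    ///     bias: None,
    ///     bits: AfqBits::Four,
    ///     group_size: AfqGroupSize::default(), // 64
    /// })?;
    /// let x = Tensor::randn(0f32, 1f32, (1, 256), &device)?;
    /// let y = layer.forward(&x)?; // (1, 256)
    /// ```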
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Afq {
                weight,
                bias,
                bits,
                group_size,
            } => {
                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;

                Ok(Self {
                    w_q,
                    scales,
                    biases,
                    bias,
                    bits,
                    group_size,
                })
            }
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
        )
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            None,
            self.group_size,
            self.bits,
            true,
        )
    }

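    /// Like [`Self::forward`], but with weight gathering: `indices` selects
    /// which expert matrix (along the leading dimension of the packed weight)
    /// each row of `x` is multiplied by. Intended for the packed
    /// mixture-of-experts layout built by [`AfqLayer::afq_packed_linear_b`].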
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            Some(indices),
            self.group_size,
            self.bits,
            true,
        )
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

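    /// Apply an additive weight delta (e.g. a LoRA-style update) by
    /// dequantizing, adding, and re-quantizing with the same bits and group
    /// size. This round-trips through the quantizer, so it is lossy.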
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize_w()?;
        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
            weight: (dequant + delta)?,
            bias: self.bias.clone(),
            bits: self.bits,
            group_size: self.group_size,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

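    /// In-situ re-quantization of an already-AFQ-quantized layer is not
    /// implemented.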
    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl AfqLayer {
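    /// Peek at a serialized AFQ artifact and report its ISQ type, skipping
    /// over tensor data rather than materializing it.
    ///
    /// The reader walks the same byte layout that [`QuantizedSerde::serialize`]
    /// writes (see `serialize_with_bias` below):
    ///
    /// ```text
    /// u32 (LE)  UQFF version
    /// u8        quantization type (must be QuantizedSerdeType::Afq)
    /// u8        has_bias flag
    /// tensor    w_q
    /// tensor    scales
    /// tensor    biases
    /// u8        bits
    /// u8        group size
    /// tensor    bias (only if has_bias)
    /// ```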
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Skip w_q, scales, and biases without materializing them.
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        if has_bias {
            fake_deserialize_tensor(&mut buffer)?
        }

        match bits {
            AfqBits::Two => Ok(IsqType::AFQ2),
            AfqBits::Three => Ok(IsqType::AFQ3),
            AfqBits::Four => Ok(IsqType::AFQ4),
            AfqBits::Six => Ok(IsqType::AFQ6),
            AfqBits::Eight => Ok(IsqType::AFQ8),
            AfqBits::Mxfp4 => candle_core::bail!("mxfp4 is not supported as an ISQ type"),
        }
    }

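    /// Load a (possibly biased) AFQ linear layer from pre-quantized
    /// checkpoint tensors, given the *unquantized* `in_dim`/`out_dim`.
    ///
    /// Expected tensor shapes, since `bits`-wide elements are packed into
    /// `u32` words and each `group_size` inputs share one scale/bias pair:
    ///
    /// ```text
    /// weight: (out_dim, in_dim * bits / 32)   u32
    /// scales: (out_dim, in_dim / group_size)
    /// biases: (out_dim, in_dim / group_size)
    /// bias:   (out_dim,)                      optional
    /// ```
    ///
    /// For example, `in_dim = 4096`, `bits = 4`, `group_size = 64` gives
    /// `4096 * 4 / 32 = 512` packed weight columns and `4096 / 64 = 64`
    /// scale/bias columns per output row.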
    pub fn afq_linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales =
            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
        let biases =
            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }

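    /// Like [`Self::afq_linear_b`], but for mixture-of-experts checkpoints in
    /// which the weights of all local experts are stacked along a leading
    /// `num_local_experts` dimension, ready for [`QuantMethod::gather_forward`].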
    pub fn afq_packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "scales",
            Default::default(),
        )?;
        let biases = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "biases",
            Default::default(),
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }
}

impl QuantizedSerde for AfqLayer {
    fn name(&self) -> &'static str {
        "afq-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
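    /// Serialize the layer into the UQFF byte layout documented on
    /// [`AfqLayer::get_isq_type_from_uqff`], writing `bias` (which may differ
    /// from `self.bias`) as the trailing optional tensor.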
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for AFQ quantization (u8)
        buffer.push(QuantizedSerdeType::Afq as u8);

        // Whether bias data is included (u8)
        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.biases)?;

        buffer.push(self.bits as u8);
        buffer.push(self.group_size as u8);

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
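    /// Reconstruct an [`AfqLayer`] from bytes produced by
    /// [`QuantizedSerde::serialize`], placing all tensors on `device`.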
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias: b,
            biases,
            bits,
            group_size,
        }))
    }
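    /// Like [`QuantizedSerde::deserialize`], but returns the trailing bias
    /// tensor separately (the returned layer itself has `bias: None`), for
    /// callers that attach the bias elsewhere.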
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                scales,
                bias: None,
                biases,
                bits,
                group_size,
            }),
            b,
        ))
    }
}