use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        UQFF_VERSION,
    },
    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
};

mod ops;

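/// Bit widths supported by AFQ (affine) quantization; the enum discriminant is
/// the number of bits stored per weight element.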
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum AfqBits {
    Two = 2,
    Three = 3,
    Four = 4,
    Six = 6,
    Eight = 8,
}

impl TryFrom<usize> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            2 => Ok(Self::Two),
            3 => Ok(Self::Three),
            4 => Ok(Self::Four),
            6 => Ok(Self::Six),
            8 => Ok(Self::Eight),
            x => candle_core::bail!("Invalid AFQ bits {x}."),
        }
    }
}

impl TryFrom<u8> for AfqBits {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

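/// Number of consecutive weight elements that share one (scale, bias) pair.
/// Smaller groups track the weights more closely at the cost of extra metadata.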
#[repr(u8)]
#[derive(Debug, Clone, Copy, Default)]
pub enum AfqGroupSize {
    Low = 32,
    #[default]
    Med = 64,
    High = 128,
}

impl TryFrom<usize> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> Result<Self> {
        match value {
            32 => Ok(Self::Low),
            64 => Ok(Self::Med),
            128 => Ok(Self::High),
            x => candle_core::bail!("Invalid AFQ group size {x}."),
        }
    }
}

impl TryFrom<u8> for AfqGroupSize {
    type Error = candle_core::Error;
    fn try_from(value: u8) -> Result<Self> {
        Self::try_from(value as usize)
    }
}

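/// A linear layer quantized with AFQ (affine quantization). Elements are
/// packed `bits` at a time into the u32 words of `w_q`, and each group of
/// `group_size` elements is reconstructed affinely from its entry in `scales`
/// and `biases`. `bias` is the ordinary additive output bias, distinct from
/// the per-group quantization `biases`.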
#[derive(Debug)]
pub struct AfqLayer {
    w_q: Tensor,
    scales: Tensor,
    biases: Tensor,
    bias: Option<Tensor>,
    bits: AfqBits,
    group_size: AfqGroupSize,
}

impl QuantMethod for AfqLayer {
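    /// Build an AFQ layer by quantizing the floating-point weight carried in
    /// the `Afq` config. All other config variants are dispatched to other
    /// `QuantMethod` impls, so reaching them here is a caller bug.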
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Unquantized(_) => unreachable!(),
            QuantMethodConfig::Afq {
                weight,
                bias,
                bits,
                group_size,
            } => {
                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;

                Ok(Self {
                    w_q,
                    scales,
                    biases,
                    bias,
                    bits,
                    group_size,
                })
            }
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.w_q,
            &self.scales,
            &self.biases,
            self.group_size,
            self.bits,
        )
    }

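    /// Matmul against the packed weights. `w_q` is stored as (out_dim, packed
    /// in_dim), so the trailing `true` presumably asks `ops::afq_mm_op` for the
    /// transposed product `x @ W^T`; the optional arguments carry gather
    /// indices, used by `gather_forward` below to select stacked expert weights.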
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            None,
            self.group_size,
            self.bits,
            true,
        )
    }

    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        ops::afq_mm_op(
            x,
            &self.w_q,
            &self.scales,
            &self.biases,
            None,
            Some(indices),
            self.group_size,
            self.bits,
            true,
        )
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize_w()?;
        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
            weight: (dequant + delta)?,
            bias: self.bias.clone(),
            bits: self.bits,
            group_size: self.group_size,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl AfqLayer {
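    /// Peek at a serialized UQFF AFQ layer and report its `IsqType` without
    /// materializing any tensors: the tensor payloads are skipped via
    /// `fake_deserialize_tensor` and only the bit width is inspected.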
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        if has_bias {
            fake_deserialize_tensor(&mut buffer)?;
        }

        match bits {
            AfqBits::Two => Ok(IsqType::AFQ2),
            AfqBits::Three => Ok(IsqType::AFQ3),
            AfqBits::Four => Ok(IsqType::AFQ4),
            AfqBits::Six => Ok(IsqType::AFQ6),
            AfqBits::Eight => Ok(IsqType::AFQ8),
        }
    }

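    /// Load an AFQ linear layer from a `ShardedVarBuilder`. The checkpoint is
    /// expected to hold `weight` as u32 words packing `bits`-bit values (hence
    /// the `in_dim * bits / 32` column count), plus per-group `scales` and
    /// `biases` of shape (out_dim, in_dim / group_size).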
    pub fn afq_linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales =
            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
        let biases =
            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }

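    /// Like [`Self::afq_linear_b`], but for checkpoints that stack the weights
    /// of `num_local_experts` MoE experts along a leading dimension.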
    pub fn afq_packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let QuantizedConfig::Afq { bits, group_size } = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let w_q = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim * bits / 32),
            "weight",
            Default::default(),
            DType::U32,
        )?;
        let scales = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "scales",
            Default::default(),
        )?;
        let biases = vb.get_with_hints(
            (num_local_experts, out_dim, in_dim / group_size),
            "biases",
            Default::default(),
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias,
            biases,
            bits: AfqBits::try_from(*bits)?,
            group_size: AfqGroupSize::try_from(*group_size)?,
        }))
    }
}

impl QuantizedSerde for AfqLayer {
    fn name(&self) -> &'static str {
        "afq-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        true
    }
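    /// Serialize to the UQFF wire format:
    ///   u32 (little-endian): UQFF version
    ///   u8: ISQ type (`QuantizedSerdeType::Afq`)
    ///   u8: 1 if a bias tensor is appended after the metadata, else 0
    ///   tensors: `w_q`, `scales`, `biases`
    ///   u8: bits; u8: group size
    ///   tensor: bias (only when the flag above is set)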
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        buffer.push(QuantizedSerdeType::Afq as u8);

        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.biases)?;

        buffer.push(self.bits as u8);
        buffer.push(self.group_size as u8);

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
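    /// Inverse of [`Self::serialize_with_bias`]: validates the version and ISQ
    /// type, then materializes the tensors on `device` while holding the
    /// quantize-onto guard.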
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            scales,
            bias: b,
            biases,
            bits,
            group_size,
        }))
    }
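    /// Same wire format as [`Self::deserialize`], but the bias is handed back
    /// to the caller instead of being stored on the layer.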
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Afq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Afq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let biases = deserialize_tensor(&mut buffer, device)?;

        let bits: AfqBits = buffer.read_u8()?.try_into()?;
        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                scales,
                bias: None,
                biases,
                bits,
                group_size,
            }),
            b,
        ))
    }
}