use std::{
    borrow::Cow,
    io::{Cursor, Read},
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{
    quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};
use candle_nn::Module;

use crate::{
    generate_isq, generate_isq_imatrix,
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

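/// A GGUF/GGML-quantized matmul weight with an optional bias, implementing
/// [`QuantMethod`].
///
/// A minimal construction sketch (hypothetical `qt: QTensor` and input `x`;
/// not taken from this crate's tests):
///
/// ```ignore
/// let mm = GgufMatMul::new(QuantMethodConfig::Gguf {
///     q_weight: Arc::new(qt),
///     b: None,
/// })?;
/// let y = mm.forward(&x)?;
/// ```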
#[derive(Debug)]
pub struct GgufMatMul {
    pub(crate) w: QMatMul,
    pub(crate) b: Option<Tensor>,
}

impl QuantMethod for GgufMatMul {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { q_weight, b } => Ok(Self {
                w: QMatMul::from_arc(q_weight)?,
                b,
            }),
            QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
        }
    }

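    // Dequantizes via f16, then widens to f32 for the caller.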
    fn dequantize_w(&self) -> Result<Tensor> {
        self.w.dequantize_f16()?.to_dtype(DType::F32)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let x = self.w.forward(a)?;
        if let Some(ref b) = self.b {
            x.broadcast_add(b)
        } else {
            Ok(x)
        }
    }

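    // Gathered (MoE-style) matmul: each row of `x` uses the expert weight
    // selected by `indices` (see `QMatMul::indexed_moe_forward`); the shared
    // bias, if any, is broadcast-added afterwards.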
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let res = self.w.indexed_moe_forward(x, indices)?;
        if let Some(ref b) = self.b {
            res.broadcast_add(b)
        } else {
            Ok(res)
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }

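    // Adds a delta to the weight (e.g. when merging an adapter). Quantized
    // weights are dequantized, updated, then requantized at their original
    // GGML dtype; unquantized weights are updated directly.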
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        match self {
            Self {
                w: QMatMul::Tensor(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::Tensor((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::TensorF16(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::TensorF16((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::QTensor(w),
                b,
            } => {
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                let w = QMatMul::QTensor(std::sync::Arc::new(
                    candle_core::quantized::QTensor::quantize(&(w + delta)?, dtype)?,
                ));
                Ok(Arc::new(Self { w, b: b.clone() }))
            }
        }
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        match &self.w {
            QMatMul::QTensor(q) => (DType::F32, q.device()),
            QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()),
        }
    }

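    // In-situ quantization (ISQ): with `Some(dtype)`, dequantize the current
    // weight and requantize it to the target type on `device`, optionally
    // steered by an importance matrix. With `None`, only move the weight (and
    // bias) to `device`, preserving the existing quantization.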
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        if let Some(dtype) = dtype {
            let t = match &self.w {
                QMatMul::QTensor(q) => q.dequantize(&q.device())?,
                QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
            };
            let dtype = dtype.try_into()?;
            let res = if let Some(imatrix_weight) = imatrix_weight {
                generate_isq_imatrix!(t, imatrix_weight, device, dtype, n_quantized, guard)
            } else {
                generate_isq!(t, device, dtype, n_quantized, guard)
            };
            Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: res,
                b: self.b.clone(),
            })?))
        } else {
            let w = match &self.w {
                QMatMul::QTensor(q) => QMatMul::QTensor(Arc::new(QTensor::quantize(
                    &q.dequantize(&device)?,
                    q.dtype(),
                )?)),
                QMatMul::Tensor(t) => QMatMul::Tensor(t.to_device(&device)?),
                QMatMul::TensorF16(t) => QMatMul::TensorF16(t.to_device(&device)?),
            };
            let b = if let Some(b) = &self.b {
                Some(b.to_device(&device)?)
            } else {
                None
            };
            Ok(Arc::new(GgufMatMul { w, b }))
        }
    }
}

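// UQFF serialization for GGUF-quantized layers. The buffer written by
// `serialize_with_bias` has this layout (all integers little-endian):
//
//   u32   UQFF version
//   u8    ISQ type tag (`QuantizedSerdeType::Gguf`)
//   u32   length of the raw quantized weight data, in bytes
//   u8    1 if a bias tensor follows the weight data, else 0
//   u32   GGML dtype tag (matching the GGML type enum values)
//   u32   number of shape dimensions, then one u32 per dimension
//   [u8]  raw quantized weight data
//   ...   optional bias tensor, appended via `serialize_tensor`
//
// Roundtrip sketch (hypothetical `layer`, `comm`, and `guard` values):
//
//   let bytes = layer.serialize()?;
//   let restored = GgufMatMul::deserialize(bytes, &Device::Cpu, &comm, guard)?;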
impl QuantizedSerde for GgufMatMul {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "gguf"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = match &self.w {
            QMatMul::QTensor(qw) => {
                let w = qw.data()?.to_vec();
                let w_shape = qw.shape().dims();
                let dtype: u32 = match qw.dtype() {
                    GgmlDType::F32 => 0,
                    GgmlDType::F16 => 1,
                    GgmlDType::Q4_0 => 2,
                    GgmlDType::Q4_1 => 3,
                    GgmlDType::Q5_0 => 6,
                    GgmlDType::Q5_1 => 7,
                    GgmlDType::Q8_0 => 8,
                    GgmlDType::Q8_1 => 9,
                    GgmlDType::Q2K => 10,
                    GgmlDType::Q3K => 11,
                    GgmlDType::Q4K => 12,
                    GgmlDType::Q5K => 13,
                    GgmlDType::Q6K => 14,
                    GgmlDType::Q8K => 15,
                    GgmlDType::BF16 => 30,
                };

                let mut buffer = Vec::new();

                buffer.extend(&UQFF_VERSION.to_le_bytes());

                buffer.push(QuantizedSerdeType::Gguf as u8);

                buffer.extend(&(w.len() as u32).to_le_bytes());

                buffer.push(bias.is_some() as u8);

                buffer.extend(&dtype.to_le_bytes());

                buffer.extend((w_shape.len() as u32).to_le_bytes());
                for dim in w_shape {
                    buffer.extend((*dim as u32).to_le_bytes());
                }

                buffer.extend(&w);

                buffer
            }
            QMatMul::TensorF16(_) | QMatMul::Tensor(_) => {
                candle_core::bail!("Cannot serialize an unquantized GgufMatMul weight")
            }
        };

        if let Some(b) = bias.as_ref() {
            serialize_tensor(&mut buffer, b)?;
        }

        Ok(Cow::from(buffer))
    }

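    // Mirror of `serialize_with_bias`: fields are read back in the exact order
    // they were written.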
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok(Arc::new(Self {
            w: QMatMul::QTensor(w.into()),
            b,
        }))
    }
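    // Same wire format as `deserialize`, but the bias (if any) is returned
    // separately and the constructed layer keeps `b: None`.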
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok((
            Arc::new(Self {
                w: QMatMul::QTensor(w.into()),
                b: None,
            }),
            b,
        ))
    }
}

impl GgufMatMul {
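    /// Reads only the UQFF header of `data` to recover the stored ISQ type,
    /// without deserializing the tensor payload.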
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

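        // Skip the data length and bias flag; only the dtype tag is needed here.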
        let _ = buffer.read_u32::<LittleEndian>()? as usize;

        let _ = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        IsqType::try_from(dtype)
    }
}