use std::{
    borrow::Cow,
    io::{Cursor, Read},
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{
    quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};
use candle_nn::{Linear, Module};

use crate::{
    generate_isq, generate_isq_imatrix,
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

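/// A linear layer backed by a GGUF/GGML-quantized weight (`QMatMul`) and an
/// optional unquantized bias tensor.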
#[derive(Debug)]
pub struct GgufMatMul {
    pub(crate) w: QMatMul,
    pub(crate) b: Option<Tensor>,
}

impl QuantMethod for GgufMatMul {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { q_weight, b } => Ok(Self {
                w: QMatMul::from_arc(q_weight)?,
                b,
            }),
            // Every other config variant is handled by its own `QuantMethod`
            // implementation, so reaching one of them here is a construction bug.
            QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
        }
    }

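    /// Fully dequantize the weight: first through the f16 path, then widened
    /// to f32.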
    fn dequantize_w(&self) -> Result<Tensor> {
        self.w.dequantize_f16()?.to_dtype(DType::F32)
    }

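    /// Quantized matmul, with the bias broadcast-added when present.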
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let x = self.w.forward(a)?;
        if let Some(ref b) = self.b {
            x.broadcast_add(b)
        } else {
            Ok(x)
        }
    }

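    /// There is no quantized gather kernel here, so fall back to dequantizing
    /// the weight and delegating to the unquantized implementation.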
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let weight = self.dequantize_w()?;
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.b.clone(),
        )))?;
        unquant.gather_forward(x, indices)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }

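    /// Return a new layer whose weight is `w + delta`. Quantized weights are
    /// dequantized, updated, then requantized to their original GGML dtype.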
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        match self {
            Self {
                w: QMatMul::Tensor(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::Tensor((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::TensorF16(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::TensorF16((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::QTensor(w),
                b,
            } => {
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                let w = QMatMul::QTensor(Arc::new(QTensor::quantize(&(w + delta)?, dtype)?));
                Ok(Arc::new(Self { w, b: b.clone() }))
            }
        }
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        match &self.w {
            QMatMul::QTensor(q) => (DType::F32, q.device()),
            QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()),
        }
    }

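    /// Requantize in-situ. With a target `dtype`, the weight is dequantized
    /// and requantized to it (optionally importance-weighted via
    /// `imatrix_weight`); with `None`, the weights are moved to `device`,
    /// requantizing quantized tensors at their existing dtype.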
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        if let Some(dtype) = dtype {
            let t = match &self.w {
                QMatMul::QTensor(q) => q.dequantize(&q.device())?,
                QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
            };
            let dtype = dtype.try_into()?;
            let res = if let Some(imatrix_weight) = imatrix_weight {
                generate_isq_imatrix!(t, imatrix_weight, device, dtype, n_quantized, guard)
            } else {
                generate_isq!(t, device, dtype, n_quantized, guard)
            };
            Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: res,
                b: self.b.clone(),
            })?))
        } else {
            let w = match &self.w {
                QMatMul::QTensor(q) => QMatMul::QTensor(Arc::new(QTensor::quantize(
                    &q.dequantize(&device)?,
                    q.dtype(),
                )?)),
                QMatMul::Tensor(t) => QMatMul::Tensor(t.to_device(&device)?),
                QMatMul::TensorF16(t) => QMatMul::TensorF16(t.to_device(&device)?),
            };
            let b = if let Some(b) = &self.b {
                Some(b.to_device(&device)?)
            } else {
                None
            };
            Ok(Arc::new(GgufMatMul { w, b }))
        }
    }
}

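// UQFF serialization for GGUF tensors. The wire layout written by
// `serialize_with_bias` (all integers little-endian) is:
//
//   u32    UQFF version
//   u8     ISQ type tag (`QuantizedSerdeType::Gguf`)
//   u32    length of the raw quantized weight data, in bytes
//   u8     whether a bias tensor follows the weight data
//   u32    GGML dtype id of the weight
//   u32    number of weight dimensions, then that many u32 dims
//   [u8]   the raw quantized weight data
//   bias   optional bias, written via `serialize_tensor`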
impl QuantizedSerde for GgufMatMul {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "gguf"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = match &self.w {
            QMatMul::QTensor(qw) => {
                let w = qw.data()?.to_vec();
                let w_shape = qw.shape().dims();
                // GGML's numeric type ids (ids 4 and 5 are unused).
                let dtype: u32 = match qw.dtype() {
                    GgmlDType::F32 => 0,
                    GgmlDType::F16 => 1,
                    GgmlDType::Q4_0 => 2,
                    GgmlDType::Q4_1 => 3,
                    GgmlDType::Q5_0 => 6,
                    GgmlDType::Q5_1 => 7,
                    GgmlDType::Q8_0 => 8,
                    GgmlDType::Q8_1 => 9,
                    GgmlDType::Q2K => 10,
                    GgmlDType::Q3K => 11,
                    GgmlDType::Q4K => 12,
                    GgmlDType::Q5K => 13,
                    GgmlDType::Q6K => 14,
                    GgmlDType::Q8K => 15,
                    GgmlDType::BF16 => 30,
                };

                let mut buffer = Vec::new();

                // UQFF version
                buffer.extend(&UQFF_VERSION.to_le_bytes());

                // ISQ type tag (`QuantizedSerdeType::Gguf`)
                buffer.push(QuantizedSerdeType::Gguf as u8);

                // Length of the raw weight data
                buffer.extend(&(w.len() as u32).to_le_bytes());

                // Whether a bias follows
                buffer.push(bias.is_some() as u8);

                // GGML dtype id
                buffer.extend(&dtype.to_le_bytes());

                // Shape: number of dims, then the dims themselves
                buffer.extend((w_shape.len() as u32).to_le_bytes());
                for dim in w_shape {
                    buffer.extend((*dim as u32).to_le_bytes());
                }

                // The raw quantized weight data
                buffer.extend(&w);

                buffer
            }
            QMatMul::TensorF16(_) | QMatMul::Tensor(_) => {
                candle_core::bail!("Cannot serialize non-quantized")
            }
        };

        if let Some(b) = bias.as_ref() {
            serialize_tensor(&mut buffer, b)?;
        }

        Ok(Cow::from(buffer))
    }

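    /// Deserialize a UQFF-serialized GGUF layer; the inverse of
    /// `serialize_with_bias` (see the layout comment above).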
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        // Map the on-disk GGML type id back to a `GgmlDType`.
        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok(Arc::new(Self {
            w: QMatMul::QTensor(w.into()),
            b,
        }))
    }
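
    /// Like `deserialize`, but returns the bias separately instead of storing
    /// it on the layer.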
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        // The bias is handed back to the caller rather than stored on the layer.
        Ok((
            Arc::new(Self {
                w: QMatMul::QTensor(w.into()),
                b: None,
            }),
            b,
        ))
    }
}

impl GgufMatMul {
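    /// Read just the UQFF header and report which `GgmlDType` the stored
    /// weight uses, mapped to its corresponding `IsqType`, without loading
    /// the tensor data.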
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        // Skip the data length and bias flag; only the dtype is needed.
        let _ = buffer.read_u32::<LittleEndian>()? as usize;

        let _ = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        IsqType::try_from(dtype)
    }
}