use std::{
    borrow::Cow,
    io::{Cursor, Read},
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{
    quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};
use candle_nn::{Linear, Module};

use crate::{
    generate_isq, generate_isq_imatrix,
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

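/// GGUF-quantized matmul with an optional bias, backed by a candle
/// [`QMatMul`].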
#[derive(Debug)]
pub struct GgufMatMul {
    pub(crate) w: QMatMul,
    pub(crate) b: Option<Tensor>,
}

impl QuantMethod for GgufMatMul {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { q_weight, b } => Ok(Self {
                w: QMatMul::from_arc(q_weight)?,
                b,
            }),
            QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
        }
    }

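    // Dequantize to f16 first, then upcast the result to f32.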
    fn dequantize_w(&self) -> Result<Tensor> {
        self.w.dequantize_f16()?.to_dtype(DType::F32)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let x = self.w.forward(a)?;
        if let Some(ref b) = self.b {
            x.broadcast_add(b)
        } else {
            Ok(x)
        }
    }

    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
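        // No dedicated quantized gather kernel here: dequantize the weight
        // and dispatch to the unquantized linear path.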
        let weight = self.dequantize_w()?;
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.b.clone(),
        )))?;
        unquant.gather_forward(x, indices)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }

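    // Apply a delta to the weight. QTensor weights are dequantized, updated,
    // and re-quantized to the same GGML dtype.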
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        match self {
            Self {
                w: QMatMul::Tensor(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::Tensor((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::TensorF16(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::TensorF16((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::QTensor(w),
                b,
            } => {
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                let w = QMatMul::QTensor(std::sync::Arc::new(
                    candle_core::quantized::QTensor::quantize(&(w + delta)?, dtype)?,
                ));
                Ok(Arc::new(Self { w, b: b.clone() }))
            }
        }
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        match &self.w {
            QMatMul::QTensor(q) => (DType::F32, q.device()),
            QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()),
        }
    }

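    // In-situ quantization: dequantize the current weight and requantize it
    // to the requested ISQ dtype, optionally guided by an imatrix. With no
    // target dtype, only move the existing weight (and bias) to `device`.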
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        if let Some(dtype) = dtype {
            let t = match &self.w {
                QMatMul::QTensor(q) => q.dequantize(&q.device())?,
                QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
            };
            let dtype = dtype.try_into()?;
            let res = if let Some(imatrix_weight) = imatrix_weight {
                generate_isq_imatrix!(t, imatrix_weight, device, dtype, n_quantized, guard)
            } else {
                generate_isq!(t, device, dtype, n_quantized, guard)
            };
            Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: res,
                b: self.b.clone(),
            })?))
        } else {
            let w = match &self.w {
                QMatMul::QTensor(q) => QMatMul::QTensor(Arc::new(QTensor::quantize(
                    &q.dequantize(&device)?,
                    q.dtype(),
                )?)),
                QMatMul::Tensor(t) => QMatMul::Tensor(t.to_device(&device)?),
                QMatMul::TensorF16(t) => QMatMul::TensorF16(t.to_device(&device)?),
            };
            let b = if let Some(b) = &self.b {
                Some(b.to_device(&device)?)
            } else {
                None
            };
            Ok(Arc::new(GgufMatMul { w, b }))
        }
    }
}

impl QuantizedSerde for GgufMatMul {
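    // Serialization structure (written by `serialize_with_bias`):
    //
    // -----------------------
    // UQFF version, u32, little endian
    // -----------------------
    // ISQ type (GGUF), u8
    // -----------------------
    // Quantized tensor data length in bytes, u32, little endian
    // -----------------------
    // Whether a bias is present, u8 boolean
    // -----------------------
    // GGML dtype, u32, little endian
    // -----------------------
    // Number of shape dims, u32, little endian
    // -----------------------
    // Each shape dim, u32, little endian
    // -----------------------
    // Raw quantized tensor data
    // -----------------------
    // [optional] bias tensor (see `serialize_tensor`)
    // -----------------------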
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "gguf"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = match &self.w {
            QMatMul::QTensor(qw) => {
                let w = qw.data()?.to_vec();
                let w_shape = qw.shape().dims();
                let dtype: u32 = match qw.dtype() {
                    GgmlDType::F32 => 0,
                    GgmlDType::F16 => 1,
                    GgmlDType::Q4_0 => 2,
                    GgmlDType::Q4_1 => 3,
                    GgmlDType::Q5_0 => 6,
                    GgmlDType::Q5_1 => 7,
                    GgmlDType::Q8_0 => 8,
                    GgmlDType::Q8_1 => 9,
                    GgmlDType::Q2K => 10,
                    GgmlDType::Q3K => 11,
                    GgmlDType::Q4K => 12,
                    GgmlDType::Q5K => 13,
                    GgmlDType::Q6K => 14,
                    GgmlDType::Q8K => 15,
                    GgmlDType::BF16 => 30,
                };

                let mut buffer = Vec::new();

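                // Version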
                buffer.extend(&UQFF_VERSION.to_le_bytes());

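                // ISQ type marker for GGUF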
                buffer.push(QuantizedSerdeType::Gguf as u8);

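                // Length of the quantized data, in bytes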
                buffer.extend(&(w.len() as u32).to_le_bytes());

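                // Whether a bias follows the tensor data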
                buffer.push(bias.is_some() as u8);

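                // GGML dtype tag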
                buffer.extend(&dtype.to_le_bytes());

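                // Shape: number of dims, then each dim as u32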
                buffer.extend((w_shape.len() as u32).to_le_bytes());
                for dim in w_shape {
                    buffer.extend((*dim as u32).to_le_bytes());
                }

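                // Raw quantized tensor data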
                buffer.extend(&w);

                buffer
            }
            QMatMul::TensorF16(_) | QMatMul::Tensor(_) => {
                candle_core::bail!("Cannot serialize non-quantized")
            }
        };

        if let Some(b) = bias.as_ref() {
            serialize_tensor(&mut buffer, b)?;
        }

        Ok(Cow::from(buffer))
    }

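    // Reads back the layout written by `serialize_with_bias` above.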
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire();
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok(Arc::new(Self {
            w: QMatMul::QTensor(w.into()),
            b,
        }))
    }
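    // Like `deserialize`, but the bias is returned separately rather than
    // stored on the layer.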
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire();
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok((
            Arc::new(Self {
                w: QMatMul::QTensor(w.into()),
                b: None,
            }),
            b,
        ))
    }
}

impl GgufMatMul {
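    /// Read just enough of a UQFF GGUF header to recover the stored ISQ type,
    /// skipping over the data length and bias flag fields.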
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let _ = buffer.read_u32::<LittleEndian>()? as usize;

        let _ = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        IsqType::try_from(dtype)
    }
}