#[cfg(not(feature = "cuda"))]
mod cpu;
#[cfg(feature = "cuda")]
mod cuda;
#[cfg(feature = "cuda")]
mod ffi;

use std::{
    borrow::Cow,
    io::{Cursor, Read},
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{
    quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};
use candle_nn::Module;

use crate::{
    generate_isq, generate_isq_imatrix,
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

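/// GGUF-style quantized matmul: a `QMatMul` weight plus an optional bias tensor.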
#[derive(Debug)]
pub struct GgufMatMul {
    pub(crate) w: QMatMul,
    pub(crate) b: Option<Tensor>,
}

impl QuantMethod for GgufMatMul {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { q_weight, b } => Ok(Self {
                w: QMatMul::from_arc(q_weight)?,
                b,
            }),
            QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
        }
    }

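    /// Dequantize the weight via f16, then widen the result to f32.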
    fn dequantize_w(&self) -> Result<Tensor> {
        self.w.dequantize_f16()?.to_dtype(DType::F32)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let x = self.w.forward(a)?;
        if let Some(ref b) = self.b {
            x.broadcast_add(b)
        } else {
            Ok(x)
        }
    }

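    /// Indexed (per-expert) forward used for MoE gather: dispatch to the CUDA kernel when the
    /// `cuda` feature is enabled, otherwise fall back to the CPU implementation.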
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        let res = cuda::qmatmul_indexed_moe_forward(&self.w, x, indices)?;

        #[cfg(not(feature = "cuda"))]
        let res = cpu::cpu_indexed_moe_forward(&self.w, x, indices)?;

        if let Some(ref b) = self.b {
            res.broadcast_add(b)
        } else {
            Ok(res)
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }

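    /// Add `delta` to the weight, requantizing back to the original GGML dtype when the weight
    /// is stored as a `QTensor`.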
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        match self {
            Self {
                w: QMatMul::Tensor(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::Tensor((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::TensorF16(w),
                b,
            } => Ok(Arc::new(Self {
                w: QMatMul::TensorF16((w + delta)?),
                b: b.clone(),
            })),
            Self {
                w: QMatMul::QTensor(w),
                b,
            } => {
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                let w = QMatMul::QTensor(std::sync::Arc::new(
                    candle_core::quantized::QTensor::quantize(&(w + delta)?, dtype)?,
                ));
                Ok(Arc::new(Self { w, b: b.clone() }))
            }
        }
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        match &self.w {
            QMatMul::QTensor(q) => (DType::F32, q.device()),
            QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()),
        }
    }

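    /// Apply in-situ quantization (ISQ): requantize to `dtype` if one is given (optionally
    /// guided by an imatrix), otherwise just move the existing weight and bias to `device`.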
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        if let Some(dtype) = dtype {
            let t = match &self.w {
                QMatMul::QTensor(q) => q.dequantize(&q.device())?,
                QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
            };
            let dtype = dtype.try_into()?;
            let res = if let Some(imatrix_weight) = imatrix_weight {
                generate_isq_imatrix!(t, imatrix_weight, device, dtype, n_quantized, guard)
            } else {
                generate_isq!(t, device, dtype, n_quantized, guard)
            };
            Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: res,
                b: self.b.clone(),
            })?))
        } else {
            let w = match &self.w {
                QMatMul::QTensor(q) => QMatMul::QTensor(Arc::new(QTensor::quantize(
                    &q.dequantize(&device)?,
                    q.dtype(),
                )?)),
                QMatMul::Tensor(t) => QMatMul::Tensor(t.to_device(&device)?),
                QMatMul::TensorF16(t) => QMatMul::TensorF16(t.to_device(&device)?),
            };
            let b = if let Some(b) = &self.b {
                Some(b.to_device(&device)?)
            } else {
                None
            };
            Ok(Arc::new(GgufMatMul { w, b }))
        }
    }
}

impl QuantizedSerde for GgufMatMul {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "gguf"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
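    // Serialized UQFF layout for a GGUF-quantized layer (integers are little-endian):
    //
    // - u32: UQFF version
    // - u8:  ISQ type (GGUF)
    // - u32: length of the quantized weight data, in bytes
    // - u8:  whether a bias tensor follows the weight data
    // - u32: GGML dtype (discriminants follow GGML's `ggml_type` numbering)
    // - u32: number of shape dims, then one u32 per dim
    // - raw quantized weight data
    // - [optional] bias tensor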
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = match &self.w {
            QMatMul::QTensor(qw) => {
                let w = qw.data()?.to_vec();
                let w_shape = qw.shape().dims();
                let dtype: u32 = match qw.dtype() {
                    GgmlDType::F32 => 0,
                    GgmlDType::F16 => 1,
                    GgmlDType::Q4_0 => 2,
                    GgmlDType::Q4_1 => 3,
                    GgmlDType::Q5_0 => 6,
                    GgmlDType::Q5_1 => 7,
                    GgmlDType::Q8_0 => 8,
                    GgmlDType::Q8_1 => 9,
                    GgmlDType::Q2K => 10,
                    GgmlDType::Q3K => 11,
                    GgmlDType::Q4K => 12,
                    GgmlDType::Q5K => 13,
                    GgmlDType::Q6K => 14,
                    GgmlDType::Q8K => 15,
                    GgmlDType::BF16 => 30,
                };

                let mut buffer = Vec::new();

                buffer.extend(&UQFF_VERSION.to_le_bytes());

                buffer.push(QuantizedSerdeType::Gguf as u8);

                buffer.extend(&(w.len() as u32).to_le_bytes());

                buffer.push(bias.is_some() as u8);

                buffer.extend(&dtype.to_le_bytes());

                buffer.extend((w_shape.len() as u32).to_le_bytes());
                for dim in w_shape {
                    buffer.extend((*dim as u32).to_le_bytes());
                }

                buffer.extend(&w);

                buffer
            }
            QMatMul::TensorF16(_) | QMatMul::Tensor(_) => {
                candle_core::bail!("Cannot serialize non-quantized")
            }
        };

        if let Some(b) = bias.as_ref() {
            serialize_tensor(&mut buffer, b)?;
        }

        Ok(Cow::from(buffer))
    }

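    /// Deserialize a layer written by `serialize_with_bias`, reading the fields back in the
    /// same order.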
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok(Arc::new(Self {
            w: QMatMul::QTensor(w.into()),
            b,
        }))
    }
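    /// Like `deserialize`, but return the bias separately instead of attaching it to the layer.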
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let data_len = buffer.read_u32::<LittleEndian>()? as usize;

        let has_bias = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }

        let mut tensor_data = vec![0; data_len];
        buffer.read_exact(&mut tensor_data)?;

        let _acquired_load_guard = guard.acquire(device);
        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
        Ok((
            Arc::new(Self {
                w: QMatMul::QTensor(w.into()),
                b: None,
            }),
            b,
        ))
    }
}

impl GgufMatMul {
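    /// Peek at a serialized UQFF buffer and recover the ISQ type from its GGML dtype field,
    /// without reading the tensor data itself.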
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Gguf as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Gguf as usize
            );
        }

        let _ = buffer.read_u32::<LittleEndian>()? as usize;

        let _ = buffer.read_u8()? != 0;

        let dtype = buffer.read_u32::<LittleEndian>()?;
        let dtype = match dtype {
            0 => GgmlDType::F32,
            1 => GgmlDType::F16,
            2 => GgmlDType::Q4_0,
            3 => GgmlDType::Q4_1,
            6 => GgmlDType::Q5_0,
            7 => GgmlDType::Q5_1,
            8 => GgmlDType::Q8_0,
            9 => GgmlDType::Q8_1,
            10 => GgmlDType::Q2K,
            11 => GgmlDType::Q3K,
            12 => GgmlDType::Q4K,
            13 => GgmlDType::Q5K,
            14 => GgmlDType::Q6K,
            15 => GgmlDType::Q8K,
            30 => GgmlDType::BF16,
            _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"),
        };

        IsqType::try_from(dtype)
    }
}