// mistralrs_quant/unquantized/mod.rs
use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};
19
/// A plain, full-precision linear layer (`y = x · Wᵀ + b`).
///
/// Besides the ordinary forward pass, this type is the staging point for
/// in-situ quantization (`apply_isq`) and for collecting imatrix activation
/// statistics (`begin_track_stats` / `end_track_stats`).
#[derive(Debug)]
pub struct UnquantLinear {
    // Weight matrix; indexed on dim 0 in `gather_forward`, so rows are
    // output features — presumably `(out_features, in_features)`, TODO confirm.
    w: Tensor,
    // Optional bias, broadcast over the activation's leading dims in `forward`.
    b: Option<Tensor>,
    // `Some` only while activation statistics are being tracked.
    stats: Option<ImatrixLayerStats>,
}
26
impl QuantMethod for UnquantLinear {
    /// Build from a `QuantMethodConfig::Unquantized` wrapping a `candle_nn::Linear`.
    ///
    /// All other config variants are constructed through their own
    /// `QuantMethod` implementations, so hitting them here is a caller bug.
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    /// The weight is already full precision; return a clone as-is.
    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

    /// Forward pass `a · wᵀ (+ b)`, dispatched by device:
    /// CUDA uses the cuBLASLt fused batch matmul when available,
    /// Metal/CPU use plain matmul + broadcast add.
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Lazily initialize the cuBLASLt wrapper for this device
        // (presumably a no-op on non-CUDA devices / repeat calls — TODO confirm).
        maybe_init_cublas_lt_wrapper(a.device().clone());

        // Broadcast the weight over the activation's leading batch dims so
        // the matmul sees matching ranks; 2-D (or other) inputs use `w` directly.
        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        // While imatrix tracking is active, fold this activation into the stats.
        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            // Target output shape: same as `a` but with the last dim replaced
            // by the weight's second-to-last dim (the output features of wᵀ).
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    // Fused matmul+bias via cuBLASLt when the controller is
                    // initialized; note the bias and result are transposed
                    // around the call — presumably to match cuBLASLt's layout
                    // expectations (TODO confirm).
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get())
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        // Fallback: plain matmul then explicit bias add.
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        // With Accelerate, compute in f32 and cast the result
                        // back to the caller's dtype.
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) = (a.device(), CUBLASLT_CONTROLLER.get()) {
            // Bias-free CUDA path still routes through cuBLASLt.
            cublaslt
                .batch_matmul(a, &w, None, None, None, None, None)?
                .t()
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

    /// Gathered forward: select rows of `w` (dim 0) with `indices`, then matmul.
    /// NOTE(review): unlike `forward`, this path applies neither the bias nor
    /// imatrix stats — confirm that is intentional for callers of this method.
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let w = self.w.index_select(indices, 0)?;

        a.broadcast_matmul(&w.t()?)
    }

    /// No forced activation dtype — activations are used as given.
    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    /// Return a new layer with `w + delta` (e.g. merging a weight delta);
    /// bias and stats handles are cloned unchanged.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

    /// In-situ quantize this layer to `dtype` on `device`.
    ///
    /// * `dtype: None` just moves the unquantized layer to `device`.
    /// * `imatrix_weight` is honored only by the GGUF k-quant branch; the
    ///   HQQ, AFQ and FP8 branches reject it with an error.
    /// * `n_quantized` is a shared progress counter.
    /// * `guard` serializes quantization work — acquired explicitly here, and
    ///   presumably acquired inside the `generate_isq*` macros for the GGUF
    ///   branch (TODO confirm).
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            // HQQ 4/8-bit quantization.
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    // Bias follows the quantized layer's device and dtype.
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            // AFQ 2/3/4/6/8-bit quantization.
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            // GGUF k-quants and legacy quant formats; the only branch that
            // can use an importance matrix.
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    // GGUF matmul keeps the bias in f32.
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            // FP8 (E4M3) cast.
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            // No quantization requested: just relocate to `device`.
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    /// Expose the raw weight and bias (always available for this type).
    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

    /// Start imatrix activation tracking; subsequent `forward` calls feed stats.
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    /// Finish tracking: compute and return the imatrix, then clear the
    /// accumulated state. Errors if tracking was never started.
    /// NOTE(review): this takes `&self`, so `stats` stays `Some` afterwards
    /// (only cleared, not dropped) — confirm that is intended.
    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}
299
300impl QuantizedSerde for UnquantLinear {
315 fn isq_serde_supported(&self) -> bool {
316 true
317 }
318 fn name(&self) -> &'static str {
319 "unquant-linear"
320 }
321 fn serialize(&self) -> Result<Cow<[u8]>> {
322 self.serialize_with_bias(self.b.clone())
323 }
324 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
325 let mut buffer = Vec::new();
326
327 buffer.extend(&UQFF_VERSION.to_le_bytes());
330
331 buffer.push(QuantizedSerdeType::Unquant as u8);
333
334 buffer.push(bias.is_some() as u8);
336
337 serialize_tensor(&mut buffer, &self.w)?;
339
340 if let Some(bias) = &bias {
341 serialize_tensor(&mut buffer, bias)?;
343 }
344
345 Ok(Cow::from(buffer))
346 }
347
348 fn deserialize(
349 data: Cow<[u8]>,
350 device: &Device,
351 _comm: &Arc<crate::Comm>,
352 guard: QuantizeOntoGuard,
353 ) -> Result<Arc<dyn QuantMethod>>
354 where
355 Self: Sized,
356 {
357 let mut buffer = Cursor::new(data);
358
359 let version = buffer.read_u32::<LittleEndian>()?;
360 if let Err(e) = version_is_compatible(version) {
361 return Err(candle_core::Error::wrap(e));
362 }
363
364 let isq_type = buffer.read_u8()? as usize;
365 if isq_type != QuantizedSerdeType::Unquant as usize {
366 candle_core::bail!(
367 "ISQ type ({isq_type}) doesn't match expected type {}",
368 QuantizedSerdeType::Unquant as usize
369 );
370 }
371
372 let has_bias = buffer.read_u8()? != 0;
373
374 let _acquired_load_guard = guard.acquire(device);
375 let w = deserialize_tensor(&mut buffer, device)?;
376
377 let b = if has_bias {
378 Some(deserialize_tensor(&mut buffer, device)?)
379 } else {
380 None
381 };
382
383 Ok(Arc::new(Self { w, b, stats: None }))
384 }
385 fn deserialize_ext_bias(
386 data: Cow<[u8]>,
387 device: &Device,
388 guard: QuantizeOntoGuard,
389 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
390 where
391 Self: Sized,
392 {
393 let mut buffer = Cursor::new(data);
394
395 let version = buffer.read_u32::<LittleEndian>()?;
396 if let Err(e) = version_is_compatible(version) {
397 return Err(candle_core::Error::wrap(e));
398 }
399
400 let isq_type = buffer.read_u8()? as usize;
401 if isq_type != QuantizedSerdeType::Unquant as usize {
402 candle_core::bail!(
403 "ISQ type ({isq_type}) doesn't match expected type {}",
404 QuantizedSerdeType::Unquant as usize
405 );
406 }
407
408 let has_bias = buffer.read_u8()? != 0;
409
410 let _acquired_load_guard = guard.acquire(device);
411 let w = deserialize_tensor(&mut buffer, device)?;
412
413 let b = if has_bias {
414 Some(deserialize_tensor(&mut buffer, device)?)
415 } else {
416 None
417 };
418
419 Ok((
420 Arc::new(Self {
421 w,
422 b: None,
423 stats: None,
424 }),
425 b,
426 ))
427 }
428}