mistralrs_quant/unquantized/
mod.rs1use std::{
2 borrow::Cow,
3 io::Cursor,
4 sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
9use candle_nn::Linear;
10
11use crate::{
12 cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
13 generate_isq, generate_isq_imatrix,
14 hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
15 utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
16 AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
17 QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
18};
19
/// A plain, non-quantized linear layer: `y = x · wᵀ (+ b)`.
///
/// This is the fallback `QuantMethod` implementation. It is also the starting
/// point for in-situ quantization: `apply_isq` converts the stored weight into
/// an HQQ / AFQ / GGUF / FP8 layer on demand.
#[derive(Debug)]
pub struct UnquantLinear {
    // Weight matrix; `forward` multiplies activations by its transpose.
    w: Tensor,
    // Optional bias, broadcast-added after the matmul.
    b: Option<Tensor>,
    // Activation statistics, populated only between `begin_track_stats` and
    // `end_track_stats` (used to build an imatrix for quantization).
    stats: Option<ImatrixLayerStats>,
}
26
impl QuantMethod for UnquantLinear {
    /// Build from a `QuantMethodConfig::Unquantized` linear layer, cloning its
    /// weight and (optional) bias. All other config variants are handled by
    /// their own `QuantMethod` implementations, so reaching them here is a bug.
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    /// The weight is already unquantized; return a clone as-is.
    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

    /// Compute `a · wᵀ (+ b)`, choosing the fastest available backend:
    /// cuBLASLt on CUDA when initialized, Accelerate-friendly f32 matmul on
    /// CPU when the `accelerate` feature is on, plain matmul otherwise.
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Lazily set up the cuBLASLt wrapper for this device (no-op elsewhere).
        maybe_init_cublas_lt_wrapper(a.device().clone());

        // Broadcast the 2-D weight across the activation's batch dims so the
        // batched matmul shapes line up; rank-2 inputs use the weight directly.
        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        // Feed activations into the imatrix statistics if tracking is active.
        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            // Expand the bias to the output shape: input dims with the last
            // dim replaced by the weight's output-feature dim.
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get())
                    {
                        // Fused matmul+bias via cuBLASLt. The bias is passed
                        // transposed and the result is transposed back with
                        // `.t()` — presumably `batch_matmul` produces the
                        // transposed product; confirm against its docs.
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        // cuBLASLt unavailable: plain matmul then bias add.
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        // Accelerate path: compute in f32, then cast back to
                        // the activation's original dtype.
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) = (a.device(), CUBLASLT_CONTROLLER.get()) {
            // No bias, CUDA with cuBLASLt: bare batched matmul (transposed
            // back, as above).
            cublaslt
                .batch_matmul(a, &w, None, None, None, None, None)?
                .t()
        } else {
            // No bias, generic fallback.
            MatMul.matmul(a, &w.t()?)
        }
    }

    /// Matmul against a row-subset of the weight selected by `indices`
    /// (indexing along dim 0). Note the bias is NOT applied here.
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let w = self.w.index_select(indices, 0)?;

        a.broadcast_matmul(&w.t()?)
    }

    /// No activation cast is needed for an unquantized layer.
    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    /// Return a new layer with `delta` added to the weight (e.g. LoRA merge);
    /// bias and stats are shared/cloned unchanged.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

    /// Quantize this layer in place ("ISQ") into the requested format.
    ///
    /// - HQQ4/HQQ8 → `HqqLayer` (imatrix unsupported)
    /// - AFQ2/3/4/6/8 → `AfqLayer` (imatrix unsupported)
    /// - GGUF K-quants / Q*_0/Q*_1 → `GgufMatMul`, optionally imatrix-weighted
    /// - F8E4M3 → `FP8Linear` (imatrix unsupported)
    /// - `None` → just move this unquantized layer to `device`
    ///
    /// `guard` serializes on-device quantization work; `n_quantized` counts
    /// completed layers for progress reporting.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    // Bias must match the quantized layer's compute dtype.
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                // The macros handle guard acquisition and the n_quantized
                // counter themselves; imatrix weighting picks the variant.
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    // GGUF matmul expects an f32 bias on the target device.
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                // No quantization requested: just relocate to `device`.
                let _acquired_quantize_guard = guard.acquire(&device);
                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

    /// Start collecting imatrix activation statistics on this layer's device.
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    /// Finish stat collection: compute the imatrix, clear the accumulator,
    /// and return the matrix. Errors if tracking was never started.
    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}
300
301impl QuantizedSerde for UnquantLinear {
316 fn isq_serde_supported(&self) -> bool {
317 true
318 }
319 fn name(&self) -> &'static str {
320 "unquant-linear"
321 }
322 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
323 self.serialize_with_bias(self.b.clone())
324 }
325 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
326 let mut buffer = Vec::new();
327
328 buffer.extend(&UQFF_VERSION.to_le_bytes());
331
332 buffer.push(QuantizedSerdeType::Unquant as u8);
334
335 buffer.push(bias.is_some() as u8);
337
338 serialize_tensor(&mut buffer, &self.w)?;
340
341 if let Some(bias) = &bias {
342 serialize_tensor(&mut buffer, bias)?;
344 }
345
346 Ok(Cow::from(buffer))
347 }
348
349 fn deserialize(
350 data: Cow<[u8]>,
351 device: &Device,
352 _comm: &Arc<crate::Comm>,
353 guard: QuantizeOntoGuard,
354 ) -> Result<Arc<dyn QuantMethod>>
355 where
356 Self: Sized,
357 {
358 let mut buffer = Cursor::new(data);
359
360 let version = buffer.read_u32::<LittleEndian>()?;
361 if let Err(e) = version_is_compatible(version) {
362 return Err(candle_core::Error::wrap(e));
363 }
364
365 let isq_type = buffer.read_u8()? as usize;
366 if isq_type != QuantizedSerdeType::Unquant as usize {
367 candle_core::bail!(
368 "ISQ type ({isq_type}) doesn't match expected type {}",
369 QuantizedSerdeType::Unquant as usize
370 );
371 }
372
373 let has_bias = buffer.read_u8()? != 0;
374
375 let _acquired_load_guard = guard.acquire(device);
376 let w = deserialize_tensor(&mut buffer, device)?;
377
378 let b = if has_bias {
379 Some(deserialize_tensor(&mut buffer, device)?)
380 } else {
381 None
382 };
383
384 Ok(Arc::new(Self { w, b, stats: None }))
385 }
386 fn deserialize_ext_bias(
387 data: Cow<[u8]>,
388 device: &Device,
389 guard: QuantizeOntoGuard,
390 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
391 where
392 Self: Sized,
393 {
394 let mut buffer = Cursor::new(data);
395
396 let version = buffer.read_u32::<LittleEndian>()?;
397 if let Err(e) = version_is_compatible(version) {
398 return Err(candle_core::Error::wrap(e));
399 }
400
401 let isq_type = buffer.read_u8()? as usize;
402 if isq_type != QuantizedSerdeType::Unquant as usize {
403 candle_core::bail!(
404 "ISQ type ({isq_type}) doesn't match expected type {}",
405 QuantizedSerdeType::Unquant as usize
406 );
407 }
408
409 let has_bias = buffer.read_u8()? != 0;
410
411 let _acquired_load_guard = guard.acquire(device);
412 let w = deserialize_tensor(&mut buffer, device)?;
413
414 let b = if has_bias {
415 Some(deserialize_tensor(&mut buffer, device)?)
416 } else {
417 None
418 };
419
420 Ok((
421 Arc::new(Self {
422 w,
423 b: None,
424 stats: None,
425 }),
426 b,
427 ))
428 }
429}