mistralrs_quant/
imatrix.rs1use std::{
2 collections::HashMap,
3 fs,
4 io::Cursor,
5 path::Path,
6 sync::{Arc, RwLock},
7};
8
9use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
10use candle_core::{Context, DType, Device, Result, Tensor, D};
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug)]
14struct ImatrixLayerStats_ {
15 row_counts: usize,
16 ncalls: usize,
17 row_accum: Tensor,
18}
19
20#[derive(Debug, Clone)]
21pub struct ImatrixLayerStats(Arc<RwLock<Option<ImatrixLayerStats_>>>);
22
23impl ImatrixLayerStats {
24 pub fn new(w: &Tensor, device: &Device) -> Result<Self> {
25 Ok(Self(Arc::new(RwLock::new(Some(ImatrixLayerStats_ {
26 row_counts: 0,
27 ncalls: 0,
28 row_accum: Tensor::zeros((w.dim(1)?,), DType::F32, device)?,
29 })))))
30 }
31
32 pub fn process(&self, inp: &Tensor) -> Result<()> {
33 let mut handle = self.0.write().unwrap();
34 let this = handle.as_mut().context("Layer stats were dinitialized!")?;
35
36 let inp = inp.reshape(((), inp.dim(D::Minus1)?))?;
37 this.ncalls += 1;
38 this.row_counts += inp.dim(D::Minus1)?;
39 this.row_accum = (&this.row_accum + inp.to_dtype(DType::F32)?.sqr()?.sum(0)?)?;
40 Ok(())
41 }
42
43 pub fn compute_imatrix(&self) -> Result<Tensor> {
44 let handle = self.0.read().unwrap();
45 let this = handle.as_ref().context("Layer stats were dinitialized!")?;
46 (&this.row_accum / this.row_counts as f64)? * this.ncalls as f64
47 }
48
49 pub fn clear(&self) -> Result<()> {
50 let mut handle = self.0.write().unwrap();
51 *handle = None;
52 Ok(())
53 }
54}
55
56#[derive(Serialize, Deserialize)]
57pub struct CollectedImatrixData(pub HashMap<usize, Option<Vec<f32>>>);
58
59impl CollectedImatrixData {
60 pub fn save_imatrix<P: AsRef<Path>>(&self, fname: P) -> Result<()> {
61 if let Some(ext) = fname.as_ref().extension() {
62 if ext != "cimatrix" {
63 candle_core::bail!(
64 "Expected a .cimatrix file to save collectd imatrix data to, got {:?}",
65 ext
66 );
67 }
68 }
69 let mut buf: Vec<u8> = Vec::new();
70 let mut cursor = Cursor::new(&mut buf);
71
72 cursor.write_u64::<LittleEndian>(self.0.len() as u64)?;
74
75 for (i, data) in &self.0 {
76 cursor.write_u64::<LittleEndian>(*i as u64)?;
78 cursor.write_u8(data.is_some() as u8)?;
80 if let Some(data) = data {
81 cursor.write_u64::<LittleEndian>(data.len() as u64)?;
83 for x in data {
85 cursor.write_f32::<LittleEndian>(*x)?;
86 }
87 }
88 }
89
90 fs::write(fname, buf)?;
91 Ok(())
92 }
93
94 pub fn load_imatrix<P: AsRef<Path>>(fname: P) -> Result<Self> {
95 let buf = fs::read(fname)?;
96 let mut cursor = Cursor::new(buf);
97
98 let mut entries = HashMap::new();
99 let num_entries = cursor.read_u64::<LittleEndian>()?;
100
101 for _ in 0..num_entries {
102 let i = cursor.read_u64::<LittleEndian>()?;
103 let has_data = cursor.read_u8()? != 0;
104 if has_data {
105 let len_data = cursor.read_u64::<LittleEndian>()?;
106 let mut data = Vec::new();
107 for _ in 0..len_data {
108 data.push(cursor.read_f32::<LittleEndian>()?);
109 }
110 entries.insert(i as usize, Some(data));
111 } else {
112 entries.insert(i as usize, None);
113 }
114 }
115
116 Ok(Self(entries))
117 }
118}