mistralrs_quant/
imatrix.rs

1use std::{
2    collections::HashMap,
3    fs,
4    io::Cursor,
5    path::Path,
6    sync::{Arc, RwLock},
7};
8
9use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
10use candle_core::{Context, DType, Device, Result, Tensor, D};
11use serde::{Deserialize, Serialize};
12
/// Mutable accumulation state for a single layer, held inside [`ImatrixLayerStats`].
#[derive(Debug)]
struct ImatrixLayerStats_ {
    /// Running count accumulated by `process`; used as the divisor in `compute_imatrix`.
    row_counts: usize,
    /// Number of `process` calls so far; used as the multiplier in `compute_imatrix`.
    ncalls: usize,
    /// Running sum of squared inputs. Created in `new` as an F32 vector of length `w.dim(1)`.
    row_accum: Tensor,
}
19
/// Handle to the activation statistics collected for one layer.
///
/// Cloning the handle clones the `Arc`, so all clones observe and mutate the same state.
/// The inner `Option` is set to `None` by `clear`; after that, `process` and
/// `compute_imatrix` return an error instead of operating on missing state.
#[derive(Debug, Clone)]
pub struct ImatrixLayerStats(Arc<RwLock<Option<ImatrixLayerStats_>>>);
22
23impl ImatrixLayerStats {
24    pub fn new(w: &Tensor, device: &Device) -> Result<Self> {
25        Ok(Self(Arc::new(RwLock::new(Some(ImatrixLayerStats_ {
26            row_counts: 0,
27            ncalls: 0,
28            row_accum: Tensor::zeros((w.dim(1)?,), DType::F32, device)?,
29        })))))
30    }
31
32    pub fn process(&self, inp: &Tensor) -> Result<()> {
33        let mut handle = self.0.write().unwrap();
34        let this = handle.as_mut().context("Layer stats were dinitialized!")?;
35
36        let inp = inp.reshape(((), inp.dim(D::Minus1)?))?;
37        this.ncalls += 1;
38        this.row_counts += inp.dim(D::Minus1)?;
39        this.row_accum = (&this.row_accum + inp.to_dtype(DType::F32)?.sqr()?.sum(0)?)?;
40        Ok(())
41    }
42
43    pub fn compute_imatrix(&self) -> Result<Tensor> {
44        let handle = self.0.read().unwrap();
45        let this = handle.as_ref().context("Layer stats were dinitialized!")?;
46        (&this.row_accum / this.row_counts as f64)? * this.ncalls as f64
47    }
48
49    pub fn clear(&self) -> Result<()> {
50        let mut handle = self.0.write().unwrap();
51        *handle = None;
52        Ok(())
53    }
54}
55
/// Collected imatrix values keyed by a `usize` index (presumably the layer index; verify
/// against callers). A `None` value marks an index for which no data is stored.
#[derive(Serialize, Deserialize)]
pub struct CollectedImatrixData(pub HashMap<usize, Option<Vec<f32>>>);
58
59impl CollectedImatrixData {
60    pub fn save_imatrix<P: AsRef<Path>>(&self, fname: P) -> Result<()> {
61        if let Some(ext) = fname.as_ref().extension() {
62            if ext != "cimatrix" {
63                candle_core::bail!(
64                    "Expected a .cimatrix file to save collectd imatrix data to, got {:?}",
65                    ext
66                );
67            }
68        }
69        let mut buf: Vec<u8> = Vec::new();
70        let mut cursor = Cursor::new(&mut buf);
71
72        // Number of entries
73        cursor.write_u64::<LittleEndian>(self.0.len() as u64)?;
74
75        for (i, data) in &self.0 {
76            // i
77            cursor.write_u64::<LittleEndian>(*i as u64)?;
78            // has data
79            cursor.write_u8(data.is_some() as u8)?;
80            if let Some(data) = data {
81                // data len
82                cursor.write_u64::<LittleEndian>(data.len() as u64)?;
83                // data
84                for x in data {
85                    cursor.write_f32::<LittleEndian>(*x)?;
86                }
87            }
88        }
89
90        fs::write(fname, buf)?;
91        Ok(())
92    }
93
94    pub fn load_imatrix<P: AsRef<Path>>(fname: P) -> Result<Self> {
95        let buf = fs::read(fname)?;
96        let mut cursor = Cursor::new(buf);
97
98        let mut entries = HashMap::new();
99        let num_entries = cursor.read_u64::<LittleEndian>()?;
100
101        for _ in 0..num_entries {
102            let i = cursor.read_u64::<LittleEndian>()?;
103            let has_data = cursor.read_u8()? != 0;
104            if has_data {
105                let len_data = cursor.read_u64::<LittleEndian>()?;
106                let mut data = Vec::new();
107                for _ in 0..len_data {
108                    data.push(cursor.read_f32::<LittleEndian>()?);
109                }
110                entries.insert(i as usize, Some(data));
111            } else {
112                entries.insert(i as usize, None);
113            }
114        }
115
116        Ok(Self(entries))
117    }
118}