mistralrs_core/utils/
unvarbuilder.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock},
4};
5
6use candle_core::{quantized::QMatMul, Tensor};
7use candle_nn::{Conv2d, Embedding, LayerNorm, Linear};
8use itertools::Itertools;
9use mistralrs_quant::QuantMethod;
10
11use crate::layers::{F32RmsNorm, QLinear, RmsNorm, ScaledEmbedding};
12
13pub trait ToTensors {
14    /// Tensor names to tensors
15    fn to_tensors(&self) -> HashMap<String, Tensor>;
16}
17
18impl ToTensors for Embedding {
19    fn to_tensors(&self) -> HashMap<String, Tensor> {
20        HashMap::from_iter([("weight".to_string(), self.embeddings().clone())])
21    }
22}
23
24impl ToTensors for ScaledEmbedding {
25    fn to_tensors(&self) -> HashMap<String, Tensor> {
26        HashMap::from_iter([("weight".to_string(), self.embeddings().clone())])
27    }
28}
29
30impl ToTensors for RmsNorm {
31    fn to_tensors(&self) -> HashMap<String, Tensor> {
32        HashMap::from_iter([("weight".to_string(), self.weight().clone())])
33    }
34}
35
36impl ToTensors for F32RmsNorm {
37    fn to_tensors(&self) -> HashMap<String, Tensor> {
38        HashMap::from_iter([("weight".to_string(), self.weight().clone())])
39    }
40}
41
42impl ToTensors for LayerNorm {
43    fn to_tensors(&self) -> HashMap<String, Tensor> {
44        let mut map = HashMap::new();
45        map.insert("weight".to_string(), self.weight().clone());
46        if let Some(bias) = self.bias() {
47            map.insert("bias".to_string(), bias.clone());
48        }
49        map
50    }
51}
52
53impl ToTensors for Linear {
54    fn to_tensors(&self) -> HashMap<String, Tensor> {
55        let mut map = HashMap::new();
56        map.insert("weight".to_string(), self.weight().clone());
57        if let Some(bias) = self.bias() {
58            map.insert("bias".to_string(), bias.clone());
59        }
60        map
61    }
62}
63
64impl ToTensors for Conv2d {
65    fn to_tensors(&self) -> HashMap<String, Tensor> {
66        let mut map = HashMap::new();
67        map.insert("weight".to_string(), self.weight().clone());
68        if let Some(bias) = self.bias() {
69            map.insert("bias".to_string(), bias.clone());
70        }
71        map
72    }
73}
74
75impl ToTensors for QLinear {
76    fn to_tensors(&self) -> HashMap<String, Tensor> {
77        let mut map = HashMap::new();
78        match self.inner_ref() {
79            QMatMul::Tensor(w) | QMatMul::TensorF16(w) => {
80                map.insert("weight".to_string(), w.clone());
81                if let Some(bias) = self.bias() {
82                    map.insert("bias".to_string(), bias.clone());
83                }
84            }
85            QMatMul::QTensor(_) => return HashMap::new(),
86        }
87        map
88    }
89}
90
91impl ToTensors for Arc<dyn QuantMethod> {
92    fn to_tensors(&self) -> HashMap<String, Tensor> {
93        let (w, b) = match self.unquant_weight_bias() {
94            Some(x) => x,
95            None => return HashMap::new(),
96        };
97        let mut map = HashMap::new();
98        map.insert("weight".to_string(), w);
99        if let Some(bias) = b {
100            map.insert("bias".to_string(), bias.clone());
101        }
102        map
103    }
104}
105
106pub struct UnVarBuilder {
107    data: Arc<RwLock<HashMap<String, Tensor>>>,
108    path: Vec<String>,
109}
110
111impl UnVarBuilder {
112    pub fn new() -> Self {
113        Self {
114            data: Arc::new(RwLock::new(HashMap::new())),
115            path: Vec::new(),
116        }
117    }
118
119    pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
120        let mut path = self.path.clone();
121        path.push(s.to_string());
122        Self {
123            data: self.data.clone(),
124            path,
125        }
126    }
127
128    pub fn pp<S: ToString>(&self, s: S) -> Self {
129        self.push_prefix(s)
130    }
131
132    pub fn path(&self) -> String {
133        self.path.iter().filter(|p| !p.trim().is_empty()).join(".")
134    }
135
136    pub fn add<T: ToTensors>(&self, item: &T) {
137        let mut data = self.data.write().expect("Write failed!");
138        let path = self.path();
139        data.extend(
140            item.to_tensors()
141                .into_iter()
142                .map(|(n, t)| (format!("{path}.{n}"), t))
143                .collect::<Vec<(_, _)>>(),
144        );
145    }
146
147    pub fn add_tensor<S: ToString>(&self, s: S, v: Tensor) {
148        let mut data = self.data.write().expect("Write failed!");
149        let mut path = self.path.clone();
150        path.push(s.to_string());
151        data.insert(
152            path.into_iter().filter(|p| !p.trim().is_empty()).join("."),
153            v,
154        );
155    }
156
157    pub fn extend(&self, other: Vec<(String, Tensor)>) {
158        let mut data = self.data.write().expect("Write failed!");
159        let path = self.path();
160        data.extend(
161            other
162                .into_iter()
163                .map(|(n, t)| {
164                    (
165                        if path.is_empty() {
166                            n
167                        } else {
168                            format!("{path}.{n}")
169                        },
170                        t,
171                    )
172                })
173                .collect::<Vec<(_, _)>>(),
174        );
175    }
176
177    pub fn to_safetensors(&self) -> Vec<(String, Tensor)> {
178        let data = self.data.read().expect("Read failed!");
179        data.iter().map(|(p, t)| (p.clone(), t.clone())).collect()
180    }
181}