mistralrs_core/utils/unvarbuilder.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock},
4};
5
6use candle_core::{quantized::QMatMul, Tensor};
7use candle_nn::{Conv2d, Embedding, LayerNorm, Linear};
8use itertools::Itertools;
9use mistralrs_quant::QuantMethod;
10
11use crate::layers::{F32RmsNorm, QLinear, RmsNorm, ScaledEmbedding};
12
13pub trait ToTensors {
14    /// Tensor names to tensors
15    fn to_tensors(&self) -> HashMap<String, Tensor>;
16}
17
18impl ToTensors for Embedding {
19    fn to_tensors(&self) -> HashMap<String, Tensor> {
20        HashMap::from_iter([("weight".to_string(), self.embeddings().clone())])
21    }
22}
23
24impl ToTensors for ScaledEmbedding {
25    fn to_tensors(&self) -> HashMap<String, Tensor> {
26        HashMap::from_iter([("weight".to_string(), self.embeddings().clone())])
27    }
28}
29
30impl ToTensors for RmsNorm {
31    fn to_tensors(&self) -> HashMap<String, Tensor> {
32        HashMap::from_iter([("weight".to_string(), self.weight().clone())])
33    }
34}
35
36impl ToTensors for F32RmsNorm {
37    fn to_tensors(&self) -> HashMap<String, Tensor> {
38        HashMap::from_iter([("weight".to_string(), self.weight().clone())])
39    }
40}
41
42impl ToTensors for LayerNorm {
43    fn to_tensors(&self) -> HashMap<String, Tensor> {
44        let mut map = HashMap::new();
45        map.insert("weight".to_string(), self.weight().clone());
46        if let Some(bias) = self.bias() {
47            map.insert("bias".to_string(), bias.clone());
48        }
49        map
50    }
51}
52
53impl ToTensors for Linear {
54    fn to_tensors(&self) -> HashMap<String, Tensor> {
55        let mut map = HashMap::new();
56        map.insert("weight".to_string(), self.weight().clone());
57        if let Some(bias) = self.bias() {
58            map.insert("bias".to_string(), bias.clone());
59        }
60        map
61    }
62}
63
64impl ToTensors for Conv2d {
65    fn to_tensors(&self) -> HashMap<String, Tensor> {
66        let mut map = HashMap::new();
67        map.insert("weight".to_string(), self.weight().clone());
68        if let Some(bias) = self.bias() {
69            map.insert("bias".to_string(), bias.clone());
70        }
71        map
72    }
73}
74
75impl ToTensors for candle_nn::Conv1d {
76    fn to_tensors(&self) -> HashMap<String, Tensor> {
77        let mut map = HashMap::new();
78        map.insert("weight".to_string(), self.weight().clone());
79        if let Some(bias) = self.bias() {
80            map.insert("bias".to_string(), bias.clone());
81        }
82        map
83    }
84}
85
86impl ToTensors for QLinear {
87    fn to_tensors(&self) -> HashMap<String, Tensor> {
88        let mut map = HashMap::new();
89        match self.inner_ref() {
90            QMatMul::Tensor(w) | QMatMul::TensorF16(w) => {
91                map.insert("weight".to_string(), w.clone());
92                if let Some(bias) = self.bias() {
93                    map.insert("bias".to_string(), bias.clone());
94                }
95            }
96            QMatMul::QTensor(_) => return HashMap::new(),
97        }
98        map
99    }
100}
101
102impl ToTensors for Arc<dyn QuantMethod> {
103    fn to_tensors(&self) -> HashMap<String, Tensor> {
104        let (w, b) = match self.unquant_weight_bias() {
105            Some(x) => x,
106            None => return HashMap::new(),
107        };
108        let mut map = HashMap::new();
109        map.insert("weight".to_string(), w);
110        if let Some(bias) = b {
111            map.insert("bias".to_string(), bias.clone());
112        }
113        map
114    }
115}
116
117pub struct UnVarBuilder {
118    data: Arc<RwLock<HashMap<String, Tensor>>>,
119    path: Vec<String>,
120}
121
122impl UnVarBuilder {
123    pub fn new() -> Self {
124        Self {
125            data: Arc::new(RwLock::new(HashMap::new())),
126            path: Vec::new(),
127        }
128    }
129
130    pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
131        let mut path = self.path.clone();
132        path.push(s.to_string());
133        Self {
134            data: self.data.clone(),
135            path,
136        }
137    }
138
139    pub fn pp<S: ToString>(&self, s: S) -> Self {
140        self.push_prefix(s)
141    }
142
143    pub fn path(&self) -> String {
144        self.path.iter().filter(|p| !p.trim().is_empty()).join(".")
145    }
146
147    pub fn add<T: ToTensors>(&self, item: &T) {
148        let mut data = self.data.write().expect("Write failed!");
149        let path = self.path();
150        data.extend(
151            item.to_tensors()
152                .into_iter()
153                .map(|(n, t)| (format!("{path}.{n}"), t))
154                .collect::<Vec<(_, _)>>(),
155        );
156    }
157
158    pub fn add_tensor<S: ToString>(&self, s: S, v: Tensor) {
159        let mut data = self.data.write().expect("Write failed!");
160        let mut path = self.path.clone();
161        path.push(s.to_string());
162        data.insert(
163            path.into_iter().filter(|p| !p.trim().is_empty()).join("."),
164            v,
165        );
166    }
167
168    pub fn extend(&self, other: Vec<(String, Tensor)>) {
169        let mut data = self.data.write().expect("Write failed!");
170        let path = self.path();
171        data.extend(
172            other
173                .into_iter()
174                .map(|(n, t)| {
175                    (
176                        if path.is_empty() {
177                            n
178                        } else {
179                            format!("{path}.{n}")
180                        },
181                        t,
182                    )
183                })
184                .collect::<Vec<(_, _)>>(),
185        );
186    }
187
188    pub fn to_safetensors(&self) -> Vec<(String, Tensor)> {
189        let data = self.data.read().expect("Read failed!");
190        data.iter().map(|(p, t)| (p.clone(), t.clone())).collect()
191    }
192}