mistralrs_core/utils/
unvarbuilder.rs1use std::{
2 collections::HashMap,
3 sync::{Arc, RwLock},
4};
5
6use candle_core::{quantized::QMatMul, Tensor};
7use candle_nn::{Conv2d, Embedding, LayerNorm, Linear};
8use itertools::Itertools;
9use mistralrs_quant::QuantMethod;
10
11use crate::layers::{F32RmsNorm, QLinear, RmsNorm, ScaledEmbedding};
12
13pub trait ToTensors {
14 fn to_tensors(&self) -> HashMap<String, Tensor>;
16}
17
18impl ToTensors for Embedding {
19 fn to_tensors(&self) -> HashMap<String, Tensor> {
20 HashMap::from_iter([("weight".to_string(), self.embeddings().clone())])
21 }
22}
23
24impl ToTensors for ScaledEmbedding {
25 fn to_tensors(&self) -> HashMap<String, Tensor> {
26 HashMap::from_iter([("weight".to_string(), self.embeddings().clone())])
27 }
28}
29
30impl ToTensors for RmsNorm {
31 fn to_tensors(&self) -> HashMap<String, Tensor> {
32 HashMap::from_iter([("weight".to_string(), self.weight().clone())])
33 }
34}
35
36impl ToTensors for F32RmsNorm {
37 fn to_tensors(&self) -> HashMap<String, Tensor> {
38 HashMap::from_iter([("weight".to_string(), self.weight().clone())])
39 }
40}
41
42impl ToTensors for LayerNorm {
43 fn to_tensors(&self) -> HashMap<String, Tensor> {
44 let mut map = HashMap::new();
45 map.insert("weight".to_string(), self.weight().clone());
46 if let Some(bias) = self.bias() {
47 map.insert("bias".to_string(), bias.clone());
48 }
49 map
50 }
51}
52
53impl ToTensors for Linear {
54 fn to_tensors(&self) -> HashMap<String, Tensor> {
55 let mut map = HashMap::new();
56 map.insert("weight".to_string(), self.weight().clone());
57 if let Some(bias) = self.bias() {
58 map.insert("bias".to_string(), bias.clone());
59 }
60 map
61 }
62}
63
64impl ToTensors for Conv2d {
65 fn to_tensors(&self) -> HashMap<String, Tensor> {
66 let mut map = HashMap::new();
67 map.insert("weight".to_string(), self.weight().clone());
68 if let Some(bias) = self.bias() {
69 map.insert("bias".to_string(), bias.clone());
70 }
71 map
72 }
73}
74
75impl ToTensors for candle_nn::Conv1d {
76 fn to_tensors(&self) -> HashMap<String, Tensor> {
77 let mut map = HashMap::new();
78 map.insert("weight".to_string(), self.weight().clone());
79 if let Some(bias) = self.bias() {
80 map.insert("bias".to_string(), bias.clone());
81 }
82 map
83 }
84}
85
86impl ToTensors for QLinear {
87 fn to_tensors(&self) -> HashMap<String, Tensor> {
88 let mut map = HashMap::new();
89 match self.inner_ref() {
90 QMatMul::Tensor(w) | QMatMul::TensorF16(w) => {
91 map.insert("weight".to_string(), w.clone());
92 if let Some(bias) = self.bias() {
93 map.insert("bias".to_string(), bias.clone());
94 }
95 }
96 QMatMul::QTensor(_) => return HashMap::new(),
97 }
98 map
99 }
100}
101
102impl ToTensors for Arc<dyn QuantMethod> {
103 fn to_tensors(&self) -> HashMap<String, Tensor> {
104 let (w, b) = match self.unquant_weight_bias() {
105 Some(x) => x,
106 None => return HashMap::new(),
107 };
108 let mut map = HashMap::new();
109 map.insert("weight".to_string(), w);
110 if let Some(bias) = b {
111 map.insert("bias".to_string(), bias.clone());
112 }
113 map
114 }
115}
116
117pub struct UnVarBuilder {
118 data: Arc<RwLock<HashMap<String, Tensor>>>,
119 path: Vec<String>,
120}
121
122impl UnVarBuilder {
123 pub fn new() -> Self {
124 Self {
125 data: Arc::new(RwLock::new(HashMap::new())),
126 path: Vec::new(),
127 }
128 }
129
130 pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
131 let mut path = self.path.clone();
132 path.push(s.to_string());
133 Self {
134 data: self.data.clone(),
135 path,
136 }
137 }
138
139 pub fn pp<S: ToString>(&self, s: S) -> Self {
140 self.push_prefix(s)
141 }
142
143 pub fn path(&self) -> String {
144 self.path.iter().filter(|p| !p.trim().is_empty()).join(".")
145 }
146
147 pub fn add<T: ToTensors>(&self, item: &T) {
148 let mut data = self.data.write().expect("Write failed!");
149 let path = self.path();
150 data.extend(
151 item.to_tensors()
152 .into_iter()
153 .map(|(n, t)| (format!("{path}.{n}"), t))
154 .collect::<Vec<(_, _)>>(),
155 );
156 }
157
158 pub fn add_tensor<S: ToString>(&self, s: S, v: Tensor) {
159 let mut data = self.data.write().expect("Write failed!");
160 let mut path = self.path.clone();
161 path.push(s.to_string());
162 data.insert(
163 path.into_iter().filter(|p| !p.trim().is_empty()).join("."),
164 v,
165 );
166 }
167
168 pub fn extend(&self, other: Vec<(String, Tensor)>) {
169 let mut data = self.data.write().expect("Write failed!");
170 let path = self.path();
171 data.extend(
172 other
173 .into_iter()
174 .map(|(n, t)| {
175 (
176 if path.is_empty() {
177 n
178 } else {
179 format!("{path}.{n}")
180 },
181 t,
182 )
183 })
184 .collect::<Vec<(_, _)>>(),
185 );
186 }
187
188 pub fn to_safetensors(&self) -> Vec<(String, Tensor)> {
189 let data = self.data.read().expect("Read failed!");
190 data.iter().map(|(p, t)| (p.clone(), t.clone())).collect()
191 }
192}