mistralrs_core/utils/
unvarbuilder.rs1use std::{
2 collections::HashMap,
3 sync::{Arc, RwLock},
4};
5
6use candle_core::{quantized::QMatMul, Tensor};
7use candle_nn::{Conv2d, Embedding, LayerNorm, Linear};
8use itertools::Itertools;
9use mistralrs_quant::QuantMethod;
10
11use crate::layers::{F32RmsNorm, QLinear, RmsNorm, ScaledEmbedding};
12
13pub trait ToTensors {
14 fn to_tensors(&self) -> HashMap<String, Tensor>;
16}
17
18impl ToTensors for Embedding {
19 fn to_tensors(&self) -> HashMap<String, Tensor> {
20 HashMap::from_iter([("weight".to_string(), self.embeddings().clone())])
21 }
22}
23
24impl ToTensors for ScaledEmbedding {
25 fn to_tensors(&self) -> HashMap<String, Tensor> {
26 HashMap::from_iter([("weight".to_string(), self.embeddings().clone())])
27 }
28}
29
30impl ToTensors for RmsNorm {
31 fn to_tensors(&self) -> HashMap<String, Tensor> {
32 HashMap::from_iter([("weight".to_string(), self.weight().clone())])
33 }
34}
35
36impl ToTensors for F32RmsNorm {
37 fn to_tensors(&self) -> HashMap<String, Tensor> {
38 HashMap::from_iter([("weight".to_string(), self.weight().clone())])
39 }
40}
41
42impl ToTensors for LayerNorm {
43 fn to_tensors(&self) -> HashMap<String, Tensor> {
44 let mut map = HashMap::new();
45 map.insert("weight".to_string(), self.weight().clone());
46 if let Some(bias) = self.bias() {
47 map.insert("bias".to_string(), bias.clone());
48 }
49 map
50 }
51}
52
53impl ToTensors for Linear {
54 fn to_tensors(&self) -> HashMap<String, Tensor> {
55 let mut map = HashMap::new();
56 map.insert("weight".to_string(), self.weight().clone());
57 if let Some(bias) = self.bias() {
58 map.insert("bias".to_string(), bias.clone());
59 }
60 map
61 }
62}
63
64impl ToTensors for Conv2d {
65 fn to_tensors(&self) -> HashMap<String, Tensor> {
66 let mut map = HashMap::new();
67 map.insert("weight".to_string(), self.weight().clone());
68 if let Some(bias) = self.bias() {
69 map.insert("bias".to_string(), bias.clone());
70 }
71 map
72 }
73}
74
75impl ToTensors for QLinear {
76 fn to_tensors(&self) -> HashMap<String, Tensor> {
77 let mut map = HashMap::new();
78 match self.inner_ref() {
79 QMatMul::Tensor(w) | QMatMul::TensorF16(w) => {
80 map.insert("weight".to_string(), w.clone());
81 if let Some(bias) = self.bias() {
82 map.insert("bias".to_string(), bias.clone());
83 }
84 }
85 QMatMul::QTensor(_) => return HashMap::new(),
86 }
87 map
88 }
89}
90
91impl ToTensors for Arc<dyn QuantMethod> {
92 fn to_tensors(&self) -> HashMap<String, Tensor> {
93 let (w, b) = match self.unquant_weight_bias() {
94 Some(x) => x,
95 None => return HashMap::new(),
96 };
97 let mut map = HashMap::new();
98 map.insert("weight".to_string(), w);
99 if let Some(bias) = b {
100 map.insert("bias".to_string(), bias.clone());
101 }
102 map
103 }
104}
105
106pub struct UnVarBuilder {
107 data: Arc<RwLock<HashMap<String, Tensor>>>,
108 path: Vec<String>,
109}
110
111impl UnVarBuilder {
112 pub fn new() -> Self {
113 Self {
114 data: Arc::new(RwLock::new(HashMap::new())),
115 path: Vec::new(),
116 }
117 }
118
119 pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
120 let mut path = self.path.clone();
121 path.push(s.to_string());
122 Self {
123 data: self.data.clone(),
124 path,
125 }
126 }
127
128 pub fn pp<S: ToString>(&self, s: S) -> Self {
129 self.push_prefix(s)
130 }
131
132 pub fn path(&self) -> String {
133 self.path.iter().filter(|p| !p.trim().is_empty()).join(".")
134 }
135
136 pub fn add<T: ToTensors>(&self, item: &T) {
137 let mut data = self.data.write().expect("Write failed!");
138 let path = self.path();
139 data.extend(
140 item.to_tensors()
141 .into_iter()
142 .map(|(n, t)| (format!("{path}.{n}"), t))
143 .collect::<Vec<(_, _)>>(),
144 );
145 }
146
147 pub fn add_tensor<S: ToString>(&self, s: S, v: Tensor) {
148 let mut data = self.data.write().expect("Write failed!");
149 let mut path = self.path.clone();
150 path.push(s.to_string());
151 data.insert(
152 path.into_iter().filter(|p| !p.trim().is_empty()).join("."),
153 v,
154 );
155 }
156
157 pub fn extend(&self, other: Vec<(String, Tensor)>) {
158 let mut data = self.data.write().expect("Write failed!");
159 let path = self.path();
160 data.extend(
161 other
162 .into_iter()
163 .map(|(n, t)| {
164 (
165 if path.is_empty() {
166 n
167 } else {
168 format!("{path}.{n}")
169 },
170 t,
171 )
172 })
173 .collect::<Vec<(_, _)>>(),
174 );
175 }
176
177 pub fn to_safetensors(&self) -> Vec<(String, Tensor)> {
178 let data = self.data.read().expect("Read failed!");
179 data.iter().map(|(p, t)| (p.clone(), t.clone())).collect()
180 }
181}