mistralrs_core/xlora_models/classifier.rs

use crate::layers::{linear, linear_no_bias};
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{activation, ops::softmax_last_dim, Dropout, Linear, Module, ModuleT};
use mistralrs_quant::ShardedVarBuilder;

use crate::ops::{TopKLastDimOp, TopKOutput};

use super::config::XLoraConfig;

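/// Softmax over the last dimension of temperature-scaled logits,
/// i.e. `softmax(x / temp)`.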
#[derive(Debug)]
struct TemperatureScaledSoftmax {
    temp: f64,
}

impl Module for TemperatureScaledSoftmax {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        softmax_last_dim(&(xs / self.temp)?)
    }
}

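/// The X-LoRA scaling head: an optional stack of hidden layers (`inner`)
/// followed by a final projection (`last`) mapping hidden states to
/// `n_classes` scalings, or `n_classes * n_layers` when layerwise scalings
/// are enabled.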
pub struct XLoraClassifier {
    last: Linear,
    inner: Vec<Box<dyn ModuleT + Send + Sync>>,
    softmax: Option<TemperatureScaledSoftmax>,
    scaling_pass_value: f64,
    model_layers: usize,
    n_classes: usize,
    pub config: XLoraConfig,
}

impl XLoraClassifier {
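    /// Build the classifier from the `ShardedVarBuilder`, loading
    /// `xlora_depth` linear layers in total (`inner.0`, `inner.1`, ... plus the
    /// final `last` projection). When `is_quantized` is set, the classifier
    /// weights are kept in F32.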
    pub fn new(
        config: XLoraConfig,
        n_layers: usize,
        n_classes: usize,
        vb: ShardedVarBuilder,
        is_quantized: bool,
    ) -> Result<Self> {
        if config.enable_softmax_topk {
            candle_core::bail!("`enable_softmax_topk` is not implemented");
        }

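        // Build the classifier stack. Depth 1 uses only the `last` projection;
        // depth 2 adds a single `inner.0` layer; deeper configurations add
        // `xlora_depth - 2` additional `xlora_size` -> `xlora_size` layers,
        // each optionally followed by ReLU and dropout.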
        let (last, inner): (Linear, Vec<Box<dyn ModuleT + Send + Sync>>) = if config.xlora_depth
            == 1
        {
            let dim = if config.layerwise_scalings {
                n_classes * n_layers
            } else {
                n_classes
            };
            assert!(vb.contains_tensor("last.weight"));
            if config.use_bias {
                assert!(vb.contains_tensor("last.bias"));
                let lin = linear(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    vec![],
                )
            } else {
                let lin = linear_no_bias(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    vec![],
                )
            }
        } else if config.xlora_depth == 2 {
            let mut inner: Vec<Box<dyn ModuleT + Send + Sync>> = Vec::new();
            assert!(vb.contains_tensor("inner.0.weight"));
            if config.use_bias {
                assert!(vb.contains_tensor("inner.0.bias"));
                let lin = linear(config.hidden_size, config.xlora_size, vb.pp("inner.0"))?;
                inner.push(Box::new(if is_quantized {
                    Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )
                } else {
                    lin
                }));
            } else {
                let lin = linear_no_bias(config.hidden_size, config.xlora_size, vb.pp("inner.0"))?;
                inner.push(Box::new(if is_quantized {
                    Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )
                } else {
                    lin
                }));
            }
            if config.enable_relu_and_dropout {
                inner.push(Box::new(activation::Activation::Relu));
                inner.push(Box::new(Dropout::new(config.xlora_dropout_p)));
            }
            let dim = if config.layerwise_scalings {
                n_classes * n_layers
            } else {
                n_classes
            };
            assert!(vb.contains_tensor("last.weight"));
            if config.use_bias {
                assert!(vb.contains_tensor("last.bias"));
                let lin = linear(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    inner,
                )
            } else {
                let lin = linear_no_bias(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    inner,
                )
            }
        } else {
            let mut inner: Vec<Box<dyn ModuleT + Send + Sync>> = Vec::new();
            assert!(vb.contains_tensor("inner.0.weight"));
            if config.use_bias {
                assert!(vb.contains_tensor("inner.0.bias"));
                let lin = linear(config.hidden_size, config.xlora_size, vb.pp("inner.0"))?;
                inner.push(Box::new(if is_quantized {
                    Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )
                } else {
                    lin
                }));
            } else {
                let lin = linear_no_bias(config.hidden_size, config.xlora_size, vb.pp("inner.0"))?;
                inner.push(Box::new(if is_quantized {
                    Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )
                } else {
                    lin
                }));
            }
            if config.enable_relu_and_dropout {
                inner.push(Box::new(activation::Activation::Relu));
                inner.push(Box::new(Dropout::new(config.xlora_dropout_p)));
            }
            for i in 1..=config.xlora_depth - 2 {
                assert!(vb.contains_tensor(&format!("inner.{i}.weight")));
                if config.use_bias {
                    assert!(vb.contains_tensor(&format!("inner.{i}.bias")));
                    let lin = linear(
                        config.xlora_size,
                        config.xlora_size,
                        vb.pp(format!("inner.{i}")),
                    )?;
                    inner.push(Box::new(Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )));
                } else {
                    let lin = linear_no_bias(
                        config.xlora_size,
                        config.xlora_size,
                        vb.pp(format!("inner.{i}")),
                    )?;
                    inner.push(Box::new(Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )));
                }
                if config.enable_relu_and_dropout {
                    inner.push(Box::new(activation::Activation::Relu));
                    inner.push(Box::new(Dropout::new(config.xlora_dropout_p)));
                }
            }
            let dim = if config.layerwise_scalings {
                n_classes * n_layers
            } else {
                n_classes
            };
            assert!(vb.contains_tensor("last.weight"));
            if config.use_bias {
                assert!(vb.contains_tensor("last.bias"));
                let lin = linear(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    inner,
                )
            } else {
                let lin = linear_no_bias(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    inner,
                )
            }
        };
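        // Mirror the per-branch handling above: keep the `last` projection in
        // F32 when the base model is quantized.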
        let last = if is_quantized {
            Linear::new(
                last.weight().to_dtype(DType::F32)?,
                last.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
            )
        } else {
            last
        };
        Ok(Self {
            last,
            inner,
            softmax: if config.enable_softmax {
                Some(TemperatureScaledSoftmax {
                    temp: config.softmax_temperature,
                })
            } else {
                None
            },
            scaling_pass_value: config.scaling_pass_value,
            model_layers: n_layers,
            n_classes,
            config,
        })
    }

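    /// Compute the adapter scalings for the given hidden states. The output
    /// has shape `(batch, seq_len, model_layers, n_classes)`; when
    /// `top_k_lora` is set, all but the top-k scalings along the last
    /// dimension are zeroed.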
    pub fn forward(&self, mut hidden_states: Tensor) -> Result<Tensor> {
        // Pass the hidden states through the hidden classifier layers, then
        // project to the scaling logits.
        for layer in &self.inner {
            hidden_states = layer.forward_t(&hidden_states, true)?;
        }
        let mut logits = self.last.forward(&hidden_states)?;

        // Without layerwise scalings the head predicts one scaling vector per
        // token; broadcast it across all model layers.
        if !self.config.layerwise_scalings {
            logits = logits.unsqueeze(2)?;
            logits = logits.expand((
                logits.dims()[0],
                logits.dims()[1],
                self.model_layers,
                logits.dims()[3],
            ))?;
        }

        let mut scalings = logits.reshape((
            logits.dims()[0],
            logits.dims()[1],
            self.model_layers,
            self.n_classes,
        ))?;
        if let Some(ref softmax) = self.softmax {
            scalings = softmax.forward(&scalings)?;
        }

        // If `top_k_lora` is set, keep only the top-k scalings per position
        // and zero out the rest.
        let scalings = if let Some(topk_lora) = self.config.top_k_lora {
            let TopKOutput { values: _, indices } = scalings.topk(topk_lora)?;

            let scalings_zeroed = scalings.zeros_like()?;
            scalings_zeroed.scatter_add(
                &indices,
                &scalings.gather(&indices, D::Minus1)?,
                D::Minus1,
            )?
        } else {
            scalings
        };

        Ok(scalings)
    }

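    /// Dummy scalings of shape `(bs, seq_len, model_layers, n_classes)`,
    /// filled with `scaling_pass_value`, used for the initial scaling pass.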
    pub fn get_dummy_scalings(
        &self,
        bs: usize,
        seq_len: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
        Tensor::full(
            self.scaling_pass_value,
            (bs, seq_len, self.model_layers, self.n_classes),
            device,
        )?
        .to_dtype(dtype)
    }

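    /// Accessor for `config.global_scaling_weight`.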
    pub fn get_global_scaling_weight(&self) -> f64 {
        self.config.global_scaling_weight
    }
}