mistralrs_core/xlora_models/classifier.rs

use crate::layers::{linear, linear_no_bias};
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{activation, ops::softmax_last_dim, Dropout, Linear, Module, ModuleT};
use mistralrs_quant::ShardedVarBuilder;

use crate::ops::{TopKLastDimOp, TopKOutput};

use super::config::XLoraConfig;

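/// Softmax over the last dimension, with the inputs divided by a fixed temperature first.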
#[derive(Debug)]
struct TemperatureScaledSoftmax {
    temp: f64,
}

impl Module for TemperatureScaledSoftmax {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        softmax_last_dim(&(xs / self.temp)?)
    }
}

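/// The X-LoRA classifier head: a stack of `inner` layers feeds the `last` projection,
/// which yields per-layer, per-adapter scalings of shape
/// `(batch, seq_len, n_layers, n_classes)`.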
pub struct XLoraClassifier {
    last: Linear,
    inner: Vec<Box<dyn ModuleT + Send + Sync>>,
    softmax: Option<TemperatureScaledSoftmax>,
    scaling_pass_value: f64,
    model_layers: usize,
    n_classes: usize,
    pub config: XLoraConfig,
}

impl XLoraClassifier {
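    /// Builds the classifier described by `config`. `n_layers` is the number of model
    /// layers that receive scalings and `n_classes` the number of adapters; when
    /// `is_quantized` is set, the classifier weights are converted to F32.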
    pub fn new(
        config: XLoraConfig,
        n_layers: usize,
        n_classes: usize,
        vb: ShardedVarBuilder,
        is_quantized: bool,
    ) -> Result<Self> {
        if config.enable_softmax_topk {
            candle_core::bail!("`enable_softmax_topk` is not implemented");
        }

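        // The head layout depends on `xlora_depth`: depth 1 uses only the `last`
        // projection, while depth >= 2 puts `xlora_depth - 1` hidden `inner` linears
        // (optionally followed by ReLU + dropout) in front of it.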
        let (last, inner): (Linear, Vec<Box<dyn ModuleT + Send + Sync>>) = if config.xlora_depth
            == 1
        {
            let dim = if config.layerwise_scalings {
                n_classes * n_layers
            } else {
                n_classes
            };
            assert!(vb.contains_tensor("last.weight"));
            if config.use_bias {
                assert!(vb.contains_tensor("last.bias"));
                let lin = linear(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    vec![],
                )
            } else {
                let lin = linear_no_bias(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    vec![],
                )
            }
        } else if config.xlora_depth == 2 {
            let mut inner: Vec<Box<dyn ModuleT + Send + Sync>> = Vec::new();
            assert!(vb.contains_tensor("inner.0.weight"));
            if config.use_bias {
                assert!(vb.contains_tensor("inner.0.bias"));
                let lin = linear(config.hidden_size, config.xlora_size, vb.pp("inner.0"))?;
                inner.push(Box::new(if is_quantized {
                    Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )
                } else {
                    lin
                }));
            } else {
                let lin = linear_no_bias(config.hidden_size, config.xlora_size, vb.pp("inner.0"))?;
                inner.push(Box::new(if is_quantized {
                    Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )
                } else {
                    lin
                }));
            }
            if config.enable_relu_and_dropout {
                inner.push(Box::new(activation::Activation::Relu));
                inner.push(Box::new(Dropout::new(config.xlora_dropout_p)));
            }
            let dim = if config.layerwise_scalings {
                n_classes * n_layers
            } else {
                n_classes
            };
            assert!(vb.contains_tensor("last.weight"));
            if config.use_bias {
                assert!(vb.contains_tensor("last.bias"));
                let lin = linear(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    inner,
                )
            } else {
                let lin = linear_no_bias(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    inner,
                )
            }
        } else {
            let mut inner: Vec<Box<dyn ModuleT + Send + Sync>> = Vec::new();
            assert!(vb.contains_tensor("inner.0.weight"));
            if config.use_bias {
                assert!(vb.contains_tensor("inner.0.bias"));
                let lin = linear(config.hidden_size, config.xlora_size, vb.pp("inner.0"))?;
                inner.push(Box::new(if is_quantized {
                    Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )
                } else {
                    lin
                }));
            } else {
                let lin = linear_no_bias(config.hidden_size, config.xlora_size, vb.pp("inner.0"))?;
                inner.push(Box::new(if is_quantized {
                    Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )
                } else {
                    lin
                }));
            }
            if config.enable_relu_and_dropout {
                inner.push(Box::new(activation::Activation::Relu));
                inner.push(Box::new(Dropout::new(config.xlora_dropout_p)));
            }
            for i in 1..=config.xlora_depth - 2 {
                assert!(vb.contains_tensor(&format!("inner.{i}.weight")));
                if config.use_bias {
                    assert!(vb.contains_tensor(&format!("inner.{i}.bias")));
                    let lin = linear(
                        config.xlora_size,
                        config.xlora_size,
                        vb.pp(format!("inner.{i}")),
                    )?;
                    inner.push(Box::new(Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )));
                } else {
                    let lin = linear_no_bias(
                        config.xlora_size,
                        config.xlora_size,
                        vb.pp(format!("inner.{i}")),
                    )?;
                    inner.push(Box::new(Linear::new(
                        lin.weight().to_dtype(DType::F32)?,
                        lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                    )));
                }
                if config.enable_relu_and_dropout {
                    inner.push(Box::new(activation::Activation::Relu));
                    inner.push(Box::new(Dropout::new(config.xlora_dropout_p)));
                }
            }
            let dim = if config.layerwise_scalings {
                n_classes * n_layers
            } else {
                n_classes
            };
            assert!(vb.contains_tensor("last.weight"));
            if config.use_bias {
                assert!(vb.contains_tensor("last.bias"));
                let lin = linear(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    inner,
                )
            } else {
                let lin = linear_no_bias(config.hidden_size, dim, vb.pp("last"))?;
                (
                    if is_quantized {
                        Linear::new(
                            lin.weight().to_dtype(DType::F32)?,
                            lin.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
                        )
                    } else {
                        lin
                    },
                    inner,
                )
            }
        };
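        // Keep the final projection in F32 when the base model is quantized.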
        let last = if is_quantized {
            Linear::new(
                last.weight().to_dtype(DType::F32)?,
                last.bias().map(|x| x.to_dtype(DType::F32).unwrap()),
            )
        } else {
            last
        };
        Ok(Self {
            last,
            inner,
            softmax: if config.enable_softmax {
                Some(TemperatureScaledSoftmax {
                    temp: config.softmax_temperature,
                })
            } else {
                None
            },
            scaling_pass_value: config.scaling_pass_value,
            model_layers: n_layers,
            n_classes,
            config,
        })
    }

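    /// Produces adapter scalings of shape `(batch, seq_len, n_layers, n_classes)`.
    /// Without `layerwise_scalings` the per-class logits are broadcast across all layers;
    /// with `top_k_lora` set, all but the top-k scalings per position are zeroed out.
    /// The inner layers run with `train = true`, so dropout stays active.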
    pub fn forward(&self, mut hidden_states: Tensor) -> Result<Tensor> {
        for layer in &self.inner {
            hidden_states = layer.forward_t(&hidden_states, true)?;
        }
        let mut logits = self.last.forward(&hidden_states)?;

        if !self.config.layerwise_scalings {
            logits = logits.unsqueeze(2)?;
            logits = logits.expand((
                logits.dims()[0],
                logits.dims()[1],
                self.model_layers,
                logits.dims()[3],
            ))?;
        }

        let mut scalings = logits.reshape((
            logits.dims()[0],
            logits.dims()[1],
            self.model_layers,
            self.n_classes,
        ))?;
        if let Some(ref softmax) = self.softmax {
            scalings = softmax.forward(&scalings)?;
        }

        let scalings = if let Some(topk_lora) = self.config.top_k_lora {
            let TopKOutput { values: _, indices } = scalings.topk(topk_lora)?;

            let scalings_zeroed = scalings.zeros_like()?;
            scalings_zeroed.scatter_add(
                &indices,
                &scalings.gather(&indices, D::Minus1)?,
                D::Minus1,
            )?
        } else {
            scalings
        };

        Ok(scalings)
    }

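    /// Returns a `(bs, seq_len, n_layers, n_classes)` tensor filled with
    /// `scaling_pass_value`, serving as placeholder scalings for the scaling pass.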
    pub fn get_dummy_scalings(
        &self,
        bs: usize,
        seq_len: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
        Tensor::full(
            self.scaling_pass_value,
            (bs, seq_len, self.model_layers, self.n_classes),
            device,
        )?
        .to_dtype(dtype)
    }

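    /// The global weight multiplying all adapter scalings, as set in the config.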
    pub fn get_global_scaling_weight(&self) -> f64 {
        self.config.global_scaling_weight
    }
}