mistralrs_core/xlora_models/mod.rs

mod classifier;
mod config;
mod gemma;
mod gemma2;
mod llama;
mod mistral;
mod mixtral;
mod phi2;
mod phi3;
mod quantized_llama;
mod quantized_phi3;
mod starcoder2;

use std::sync::Arc;

use crate::{
    lora::Ordering,
    pipeline::{text_models_inputs_processor::FlashParams, EitherCache},
};
use candle_core::{DType, Device, Result, Tensor};
pub(crate) use config::XLoraConfig;
pub(crate) use gemma::XLoraModel as XLoraGemma;
pub(crate) use gemma2::Model as XLoraGemma2;
pub(crate) use llama::XLoraLlama;
pub(crate) use mistral::XLoraModel as XLoraMistral;
pub(crate) use mixtral::XLoraModel as XLoraMixtral;
pub(crate) use phi2::Model as XLoraPhi2;
pub(crate) use phi3::Model as XLoraPhi3;
pub(crate) use quantized_llama::ModelWeights as XLoraQLlama;
pub(crate) use quantized_phi3::ModelWeights as XLoraQPhi3;
pub(crate) use starcoder2::Model as XLoraStarcoder2;
use tokio::sync::Mutex;

use crate::{get_mut_arcmutex, pipeline::Cache};

use self::classifier::XLoraClassifier;

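/// State for "non-granular" X-LoRA scalings: `non_granular_index` counts decode steps,
/// and once it reaches `tgt_non_granular_index` the computed scalings are cached and
/// reused for all subsequent tokens.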
pub struct NonGranularState {
    pub non_granular_index: Arc<Mutex<usize>>,
    pub tgt_non_granular_index: usize,
}

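/// Implemented by X-LoRA models: exposes the classifier, the cache, and the model
/// forward pass needed to compute per-token adapter scalings via a scaling pass.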
trait ScalingsMaker {
    /// The X-LoRA classifier that turns hidden states into adapter scalings.
    fn get_classifier(&self) -> &XLoraClassifier;
    /// The dtype to use when constructing the dummy scalings for the scaling pass.
    fn dtype(&self) -> DType;
    /// Model forward pass; `is_scaling_pass` carries the scaling pass value when this
    /// call is the preliminary scaling pass.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor>;
    /// The model's cache; X-LoRA models use the `full()` variant.
    fn get_cache(&self) -> &EitherCache;

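    /// Compute the X-LoRA scalings: run a preliminary forward pass with dummy scalings,
    /// feed the resulting hidden states to the classifier, and (with non-granular
    /// scalings) cache the result once the target step index is reached.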
    #[allow(clippy::too_many_arguments)]
    fn get_scalings(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        position_ids: &[usize],
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        let (b_size, _) = input_ids_full.dims2()?;
        let (_, seq_len) = input_ids.dims2()?;

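        // With non-granular scalings, reuse the cached scalings if they exist and count
        // single-token (decode) steps toward the target index.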
        if let Some(ref non_granular_state) = non_granular_state {
            if let Some(scalings_cache) = &*self.get_cache().full().get_scalings_cache() {
                return Ok(scalings_cache.clone());
            }
            if seq_len == 1 {
                *get_mut_arcmutex!(non_granular_state.non_granular_index) += 1;
            }
        }

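        // Scaling pass: run the model with dummy scalings to obtain the hidden states
        // that the classifier will turn into the real scalings.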
        let dummy_scalings = self.get_classifier().get_dummy_scalings(
            b_size,
            seq_len,
            input_ids.device(),
            self.dtype(),
        )?;
        // Using X-LoRA cache here
        let hidden_states = if no_kv_cache {
            let res = self.forward(
                input_ids_full,
                seqlen_offsets_full,
                dummy_scalings,
                true,
                no_kv_cache,
                Some(self.get_classifier().config.scaling_pass_value),
                position_ids,
                flash_params_full,
            )?;

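            // Overwrite the regular cache with placeholder tensors, one pair per layer
            // of the X-LoRA cache.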
            let mut new_cache = Vec::new();
            for _ in 0..self.get_cache().full().xlora_lock().len() {
                new_cache.push(Some((
                    Tensor::zeros((1,), DType::U8, &Device::Cpu)?,
                    Tensor::zeros((1,), DType::U8, &Device::Cpu)?,
                )));
            }
            self.get_cache().full().lock().clone_from(&new_cache);

            res
        } else {
            self.forward(
                input_ids,
                seqlen_offsets,
                dummy_scalings,
                false,
                no_kv_cache,
                Some(self.get_classifier().config.scaling_pass_value),
                position_ids,
                flash_params,
            )?
        };

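        // The classifier maps the hidden states to adapter scalings; once the target
        // non-granular index is reached, cache them for reuse on later steps.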
        let scalings = self.get_classifier().forward(hidden_states)?;
        if let Some(ref non_granular_state) = non_granular_state {
            if *get_mut_arcmutex!(non_granular_state.non_granular_index)
                == non_granular_state.tgt_non_granular_index
            {
                *self.get_cache().full().get_scalings_cache() = Some(scalings.clone());
            }
        }
        Ok(scalings)
    }
}

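/// Verify that every layer path in the adapter ordering ends with one of the layer
/// names supported by this model, returning an error otherwise.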
fn verify_sanity_adapters(ordering: &Ordering, supported_layers: &[&str]) -> Result<()> {
    if ordering.layers.is_none() {
        return Ok(());
    }
    for path in ordering.layers.as_ref().unwrap().keys() {
        if !supported_layers.iter().any(|layer| path.ends_with(layer)) {
            candle_core::bail!("Got a layer name `{path}` in the ordering, expected it to end with one of {supported_layers:?}");
        }
    }
    Ok(())
}