mistralrs_core/xlora_models/mod.rs

mod classifier;
mod config;
mod gemma;
mod gemma2;
mod llama;
mod mistral;
mod mixtral;
mod phi2;
mod phi3;
mod quantized_llama;
mod quantized_phi3;
mod starcoder2;

use std::sync::Arc;

use crate::{
    lora::Ordering,
    pipeline::{text_models_inputs_processor::FlashParams, EitherCache},
};
use candle_core::{DType, Device, Result, Tensor};
pub(crate) use config::XLoraConfig;
pub(crate) use gemma::XLoraModel as XLoraGemma;
pub(crate) use gemma2::Model as XLoraGemma2;
pub(crate) use llama::XLoraLlama;
pub(crate) use mistral::XLoraModel as XLoraMistral;
pub(crate) use mixtral::XLoraModel as XLoraMixtral;
pub(crate) use phi2::Model as XLoraPhi2;
pub(crate) use phi3::Model as XLoraPhi3;
pub(crate) use quantized_llama::ModelWeights as XLoraQLlama;
pub(crate) use quantized_phi3::ModelWeights as XLoraQPhi3;
pub(crate) use starcoder2::Model as XLoraStarcoder2;
use tokio::sync::Mutex;

use crate::{get_mut_arcmutex, pipeline::Cache};

use self::classifier::XLoraClassifier;

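/// Shared state for non-granular X-LoRA scalings: `non_granular_index` counts
/// decode steps, and once it reaches `tgt_non_granular_index` the computed
/// scalings are cached and reused for subsequent tokens.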
pub struct NonGranularState {
    pub non_granular_index: Arc<Mutex<usize>>,
    pub tgt_non_granular_index: usize,
}

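/// Implemented by each X-LoRA model to provide what `get_scalings` needs: the
/// classifier, the dtype for dummy scalings, the KV cache, and a forward pass
/// that accepts scalings.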
trait ScalingsMaker {
    fn get_classifier(&self) -> &XLoraClassifier;
    fn dtype(&self) -> DType;
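    /// Run one forward pass of the underlying model with the provided X-LoRA
    /// scalings; `is_scaling_pass` is `Some` during the dummy scaling pass.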
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor>;
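    /// Access the model's KV cache (X-LoRA uses the `full` variant).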
    fn get_cache(&self) -> &EitherCache;

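    /// Compute the X-LoRA scalings for a batch. A first "scaling" pass is run
    /// with dummy scalings to produce hidden states, which the classifier then
    /// maps to real scalings. With a `NonGranularState`, scalings are served
    /// from the cache once the target step has been reached.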
    #[allow(clippy::too_many_arguments)]
    fn get_scalings(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        position_ids: &[usize],
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        let (b_size, _) = input_ids_full.dims2()?;
        let (_, seq_len) = input_ids.dims2()?;

        if let Some(ref non_granular_state) = non_granular_state {
            if let Some(scalings_cache) = &*self.get_cache().full().get_scalings_cache() {
                return Ok(scalings_cache.clone());
            }
            if seq_len == 1 {
                *get_mut_arcmutex!(non_granular_state.non_granular_index) += 1;
            }
        }

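        // Scaling pass: run the model with neutral dummy scalings to obtain
        // hidden states for the classifier.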
        let dummy_scalings = self.get_classifier().get_dummy_scalings(
            b_size,
            seq_len,
            input_ids.device(),
            self.dtype(),
        )?;
        let hidden_states = if no_kv_cache {
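            // Without a KV cache, the scaling pass must see the full prompt.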
            let res = self.forward(
                input_ids_full,
                seqlen_offsets_full,
                dummy_scalings,
                true,
                no_kv_cache,
                Some(self.get_classifier().config.scaling_pass_value),
                position_ids,
                flash_params_full,
            )?;

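            // Reset the main cache to 1-element placeholder tensors after the
            // full-sequence pass.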
            let mut new_cache = Vec::new();
            for _ in 0..self.get_cache().full().xlora_lock().len() {
                new_cache.push(Some((
                    Tensor::zeros((1,), DType::U8, &Device::Cpu)?,
                    Tensor::zeros((1,), DType::U8, &Device::Cpu)?,
                )));
            }
            self.get_cache().full().lock().clone_from(&new_cache);

            res
        } else {
            self.forward(
                input_ids,
                seqlen_offsets,
                dummy_scalings,
                false,
                no_kv_cache,
                Some(self.get_classifier().config.scaling_pass_value),
                position_ids,
                flash_params,
            )?
        };

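        // Map hidden states to per-adapter scalings; freeze them in the cache
        // once the target non-granular step is reached.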
        let scalings = self.get_classifier().forward(hidden_states)?;
        if let Some(ref non_granular_state) = non_granular_state {
            if *get_mut_arcmutex!(non_granular_state.non_granular_index)
                == non_granular_state.tgt_non_granular_index
            {
                *self.get_cache().full().get_scalings_cache() = Some(scalings.clone());
            }
        }
        Ok(scalings)
    }
}

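/// Verify that every layer path in the adapter ordering ends with one of the
/// layer names supported by this model.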
fn verify_sanity_adapters(ordering: &Ordering, supported_layers: &[&str]) -> Result<()> {
    if ordering.layers.is_none() {
        return Ok(());
    }
    for path in ordering.layers.as_ref().unwrap().keys() {
        if !supported_layers.iter().any(|layer| path.ends_with(layer)) {
            candle_core::bail!("Got a layer name `{path}` in the ordering, expected it to end with one of {supported_layers:?}");
        }
    }
    Ok(())
}