mistralrs/
xlora_model.rs

1use mistralrs_core::*;
2
3use crate::{best_device, Model, TextModelBuilder};
4
/// Wrapper of [`TextModelBuilder`] for X-LoRA models.
///
/// Holds a base text-model configuration together with the X-LoRA-specific values that
/// [`XLoraModelBuilder::build`] forwards to the loader's `with_xlora` step.
pub struct XLoraModelBuilder {
    /// Base text-model settings; `build` reads its fields to configure the loader and runner.
    text_model: TextModelBuilder,
    /// Identifier of the X-LoRA model, passed as the first argument to `with_xlora`.
    xlora_model_id: String,
    /// X-LoRA ordering, passed through to `with_xlora` unchanged.
    ordering: Ordering,
    /// Optional target non-granular index; `None` unless set via
    /// [`XLoraModelBuilder::tgt_non_granular_index`].
    tgt_non_granular_index: Option<usize>,
}
12
13impl XLoraModelBuilder {
14    pub fn from_text_model_builder(
15        text_model: TextModelBuilder,
16        xlora_model_id: impl ToString,
17        ordering: Ordering,
18    ) -> Self {
19        Self {
20            text_model,
21            xlora_model_id: xlora_model_id.to_string(),
22            ordering,
23            tgt_non_granular_index: None,
24        }
25    }
26
27    pub fn tgt_non_granular_index(mut self, tgt_non_granular_idx: usize) -> Self {
28        self.tgt_non_granular_index = Some(tgt_non_granular_idx);
29        self
30    }
31
32    pub async fn build(self) -> anyhow::Result<Model> {
33        let config = NormalSpecificConfig {
34            prompt_chunksize: self.text_model.prompt_chunksize,
35            topology: self.text_model.topology,
36            organization: self.text_model.organization,
37            write_uqff: self.text_model.write_uqff,
38            from_uqff: self.text_model.from_uqff,
39            imatrix: None,
40            calibration_file: None,
41            hf_cache_path: self.text_model.hf_cache_path,
42        };
43
44        if self.text_model.with_logging {
45            initialize_logging();
46        }
47
48        let loader = NormalLoaderBuilder::new(
49            config,
50            self.text_model.chat_template,
51            self.text_model.tokenizer_json,
52            Some(self.text_model.model_id),
53            self.text_model.no_kv_cache,
54            self.text_model.jinja_explicit,
55        )
56        .with_xlora(
57            self.xlora_model_id,
58            self.ordering,
59            self.text_model.no_kv_cache,
60            self.tgt_non_granular_index,
61        )
62        .build(self.text_model.loader_type)?;
63
64        // Load, into a Pipeline
65        let pipeline = loader.load_model_from_hf(
66            self.text_model.hf_revision,
67            self.text_model.token_source,
68            &self.text_model.dtype,
69            &best_device(self.text_model.force_cpu)?,
70            !self.text_model.with_logging,
71            self.text_model
72                .device_mapping
73                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
74            self.text_model.isq,
75            self.text_model.paged_attn_cfg,
76        )?;
77
78        let scheduler_method = match self.text_model.paged_attn_cfg {
79            Some(_) => {
80                let config = pipeline
81                    .lock()
82                    .await
83                    .get_metadata()
84                    .cache_config
85                    .as_ref()
86                    .unwrap()
87                    .clone();
88
89                SchedulerConfig::PagedAttentionMeta {
90                    max_num_seqs: self.text_model.max_num_seqs,
91                    config,
92                }
93            }
94            None => SchedulerConfig::DefaultScheduler {
95                method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
96            },
97        };
98
99        let mut runner = MistralRsBuilder::new(
100            pipeline,
101            scheduler_method,
102            self.text_model.throughput_logging,
103            self.text_model.search_bert_model,
104        )
105        .with_no_kv_cache(self.text_model.no_kv_cache)
106        .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());
107
108        if let Some(n) = self.text_model.prefix_cache_n {
109            runner = runner.with_prefix_cache_n(n)
110        }
111
112        Ok(Model::new(runner.build()))
113    }
114}