mistralrs/
xlora_model.rs

1use mistralrs_core::*;
2
3use crate::{best_device, Model, TextModelBuilder};
4
5/// Wrapper of [`TextModelBuilder`] for X-LoRA models.
6pub struct XLoraModelBuilder {
7    text_model: TextModelBuilder,
8    xlora_model_id: String,
9    ordering: Ordering,
10    tgt_non_granular_index: Option<usize>,
11}
12
13impl XLoraModelBuilder {
14    pub fn from_text_model_builder(
15        text_model: TextModelBuilder,
16        xlora_model_id: impl ToString,
17        ordering: Ordering,
18    ) -> Self {
19        Self {
20            text_model,
21            xlora_model_id: xlora_model_id.to_string(),
22            ordering,
23            tgt_non_granular_index: None,
24        }
25    }
26
27    pub fn tgt_non_granular_index(mut self, tgt_non_granular_idx: usize) -> Self {
28        self.tgt_non_granular_index = Some(tgt_non_granular_idx);
29        self
30    }
31
32    pub async fn build(self) -> anyhow::Result<Model> {
33        let config = NormalSpecificConfig {
34            use_flash_attn: self.text_model.use_flash_attn,
35            prompt_chunksize: self.text_model.prompt_chunksize,
36            topology: self.text_model.topology,
37            organization: self.text_model.organization,
38            write_uqff: self.text_model.write_uqff,
39            from_uqff: self.text_model.from_uqff,
40            imatrix: None,
41            calibration_file: None,
42            hf_cache_path: self.text_model.hf_cache_path,
43        };
44
45        if self.text_model.with_logging {
46            initialize_logging();
47        }
48
49        let loader = NormalLoaderBuilder::new(
50            config,
51            self.text_model.chat_template,
52            self.text_model.tokenizer_json,
53            Some(self.text_model.model_id),
54            self.text_model.no_kv_cache,
55            self.text_model.jinja_explicit,
56        )
57        .with_xlora(
58            self.xlora_model_id,
59            self.ordering,
60            self.text_model.no_kv_cache,
61            self.tgt_non_granular_index,
62        )
63        .build(self.text_model.loader_type)?;
64
65        // Load, into a Pipeline
66        let pipeline = loader.load_model_from_hf(
67            self.text_model.hf_revision,
68            self.text_model.token_source,
69            &self.text_model.dtype,
70            &best_device(self.text_model.force_cpu)?,
71            !self.text_model.with_logging,
72            self.text_model
73                .device_mapping
74                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
75            self.text_model.isq,
76            self.text_model.paged_attn_cfg,
77        )?;
78
79        let scheduler_method = match self.text_model.paged_attn_cfg {
80            Some(_) => {
81                let config = pipeline
82                    .lock()
83                    .await
84                    .get_metadata()
85                    .cache_config
86                    .as_ref()
87                    .unwrap()
88                    .clone();
89
90                SchedulerConfig::PagedAttentionMeta {
91                    max_num_seqs: self.text_model.max_num_seqs,
92                    config,
93                }
94            }
95            None => SchedulerConfig::DefaultScheduler {
96                method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
97            },
98        };
99
100        let mut runner = MistralRsBuilder::new(
101            pipeline,
102            scheduler_method,
103            self.text_model.throughput_logging,
104            self.text_model.search_bert_model,
105        )
106        .with_no_kv_cache(self.text_model.no_kv_cache)
107        .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());
108
109        if let Some(n) = self.text_model.prefix_cache_n {
110            runner = runner.with_prefix_cache_n(n)
111        }
112
113        Ok(Model::new(runner.build()))
114    }
115}