// mistralrs/xlora_model.rs

1use mistralrs_core::*;
2
3use crate::{best_device, Model, TextModelBuilder};
4
5/// Wrapper of [`TextModelBuilder`] for X-LoRA models.
6pub struct XLoraModelBuilder {
7    text_model: TextModelBuilder,
8    xlora_model_id: String,
9    ordering: Ordering,
10    tgt_non_granular_index: Option<usize>,
11}
12
13impl XLoraModelBuilder {
14    pub fn from_text_model_builder(
15        text_model: TextModelBuilder,
16        xlora_model_id: impl ToString,
17        ordering: Ordering,
18    ) -> Self {
19        Self {
20            text_model,
21            xlora_model_id: xlora_model_id.to_string(),
22            ordering,
23            tgt_non_granular_index: None,
24        }
25    }
26
27    pub fn tgt_non_granular_index(mut self, tgt_non_granular_idx: usize) -> Self {
28        self.tgt_non_granular_index = Some(tgt_non_granular_idx);
29        self
30    }
31
32    pub async fn build(self) -> anyhow::Result<Model> {
33        let config = NormalSpecificConfig {
34            prompt_chunksize: self.text_model.prompt_chunksize,
35            topology: self.text_model.topology,
36            organization: self.text_model.organization,
37            write_uqff: self.text_model.write_uqff,
38            from_uqff: self.text_model.from_uqff,
39            imatrix: None,
40            calibration_file: None,
41            hf_cache_path: self.text_model.hf_cache_path,
42            matformer_config_path: None,
43            matformer_slice_name: None,
44        };
45
46        if self.text_model.with_logging {
47            initialize_logging();
48        }
49
50        let loader = NormalLoaderBuilder::new(
51            config,
52            self.text_model.chat_template,
53            self.text_model.tokenizer_json,
54            Some(self.text_model.model_id),
55            self.text_model.no_kv_cache,
56            self.text_model.jinja_explicit,
57        )
58        .with_xlora(
59            self.xlora_model_id,
60            self.ordering,
61            self.text_model.no_kv_cache,
62            self.tgt_non_granular_index,
63        )
64        .build(self.text_model.loader_type)?;
65
66        // Load, into a Pipeline
67        let pipeline = loader.load_model_from_hf(
68            self.text_model.hf_revision,
69            self.text_model.token_source,
70            &self.text_model.dtype,
71            &best_device(self.text_model.force_cpu)?,
72            !self.text_model.with_logging,
73            self.text_model
74                .device_mapping
75                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
76            self.text_model.isq,
77            self.text_model.paged_attn_cfg,
78        )?;
79
80        let scheduler_method = match self.text_model.paged_attn_cfg {
81            Some(_) => {
82                let config = pipeline
83                    .lock()
84                    .await
85                    .get_metadata()
86                    .cache_config
87                    .as_ref()
88                    .unwrap()
89                    .clone();
90
91                SchedulerConfig::PagedAttentionMeta {
92                    max_num_seqs: self.text_model.max_num_seqs,
93                    config,
94                }
95            }
96            None => SchedulerConfig::DefaultScheduler {
97                method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
98            },
99        };
100
101        let mut runner = MistralRsBuilder::new(
102            pipeline,
103            scheduler_method,
104            self.text_model.throughput_logging,
105            self.text_model.search_bert_model,
106        );
107        if let Some(cb) = self.text_model.search_callback.clone() {
108            runner = runner.with_search_callback(cb);
109        }
110        for (name, cb) in &self.text_model.tool_callbacks {
111            runner = runner.with_tool_callback(name.clone(), cb.clone());
112        }
113        runner = runner
114            .with_no_kv_cache(self.text_model.no_kv_cache)
115            .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());
116
117        if let Some(n) = self.text_model.prefix_cache_n {
118            runner = runner.with_prefix_cache_n(n)
119        }
120
121        Ok(Model::new(runner.build().await))
122    }
123}