mistralrs/lora_model.rs

use mistralrs_core::*;

use crate::{best_device, Model, TextModelBuilder};

/// Wrapper of [`TextModelBuilder`] for LoRA models.
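///
/// A minimal usage sketch, assuming the crate root re-exports both builders;
/// the model and adapter IDs below are placeholders:
///
/// ```no_run
/// use mistralrs::{LoraModelBuilder, TextModelBuilder};
///
/// # async fn run() -> anyhow::Result<()> {
/// let model = LoraModelBuilder::from_text_model_builder(
///     TextModelBuilder::new("base-model-id"),
///     ["lora-adapter-id"],
/// )
/// .build()
/// .await?;
/// # Ok(())
/// # }
/// ```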
pub struct LoraModelBuilder {
    text_model: TextModelBuilder,
    lora_adapter_ids: Vec<String>,
}

impl LoraModelBuilder {
    pub fn from_text_model_builder(
        text_model: TextModelBuilder,
        lora_adapter_ids: impl IntoIterator<Item = impl ToString>,
    ) -> Self {
        Self {
            text_model,
            lora_adapter_ids: lora_adapter_ids
                .into_iter()
                .map(|x| x.to_string())
                .collect(),
        }
    }

    pub async fn build(self) -> anyhow::Result<Model> {
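        // Forward the wrapped TextModelBuilder's options into the loader config.
        // The imatrix, calibration, and Matformer fields are not exposed by this
        // builder, so they are left unset.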
        let config = NormalSpecificConfig {
            prompt_chunksize: self.text_model.prompt_chunksize,
            topology: self.text_model.topology,
            organization: self.text_model.organization,
            write_uqff: self.text_model.write_uqff,
            from_uqff: self.text_model.from_uqff,
            imatrix: None,
            calibration_file: None,
            hf_cache_path: self.text_model.hf_cache_path,
            matformer_config_path: None,
            matformer_slice_name: None,
        };

        if self.text_model.with_logging {
            initialize_logging();
        }

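        // Build the plain-text loader from the wrapped options and attach the
        // LoRA adapter IDs before resolving the concrete loader type.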
        let loader = NormalLoaderBuilder::new(
            config,
            self.text_model.chat_template,
            self.text_model.tokenizer_json,
            Some(self.text_model.model_id),
            self.text_model.no_kv_cache,
            self.text_model.jinja_explicit,
        )
        .with_lora(self.lora_adapter_ids)
        .build(self.text_model.loader_type)?;

        // Load the model from Hugging Face into a Pipeline.
        let pipeline = loader.load_model_from_hf(
            self.text_model.hf_revision,
            self.text_model.token_source,
            &self.text_model.dtype,
            &best_device(self.text_model.force_cpu)?,
            !self.text_model.with_logging,
            self.text_model
                .device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            self.text_model.isq,
            self.text_model.paged_attn_cfg,
        )?;

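        // Pick the scheduler: PagedAttention metadata when a paged-attention
        // config is present, otherwise the fixed-size default scheduler.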
        let scheduler_method = match self.text_model.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .unwrap()
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.text_model.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
            },
        };

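        // Assemble the runner, forwarding the search/tool callbacks and the KV
        // and prefix cache settings from the wrapped TextModelBuilder.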
        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.text_model.throughput_logging,
            self.text_model.search_bert_model,
        );
        if let Some(cb) = self.text_model.search_callback.clone() {
            runner = runner.with_search_callback(cb);
        }
        for (name, cb) in &self.text_model.tool_callbacks {
            runner = runner.with_tool_callback(name.clone(), cb.clone());
        }
        runner = runner
            .with_no_kv_cache(self.text_model.no_kv_cache)
            .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());

        if let Some(n) = self.text_model.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n)
        }

        Ok(Model::new(runner.build().await))
    }
}