mistralrs/
lora_model.rs

1use mistralrs_core::*;
2
3use crate::{best_device, Model, TextModelBuilder};
4
5/// Wrapper of [`TextModelBuilder`] for LoRA models.
6pub struct LoraModelBuilder {
7    text_model: TextModelBuilder,
8    lora_adapter_ids: Vec<String>,
9}
10
11impl LoraModelBuilder {
12    pub fn from_text_model_builder(
13        text_model: TextModelBuilder,
14        lora_adapter_ids: impl IntoIterator<Item = impl ToString>,
15    ) -> Self {
16        Self {
17            text_model,
18            lora_adapter_ids: lora_adapter_ids
19                .into_iter()
20                .map(|x| x.to_string())
21                .collect(),
22        }
23    }
24
25    pub async fn build(self) -> anyhow::Result<Model> {
26        let config = NormalSpecificConfig {
27            prompt_chunksize: self.text_model.prompt_chunksize,
28            topology: self.text_model.topology,
29            organization: self.text_model.organization,
30            write_uqff: self.text_model.write_uqff,
31            from_uqff: self.text_model.from_uqff,
32            imatrix: None,
33            calibration_file: None,
34            hf_cache_path: self.text_model.hf_cache_path,
35        };
36
37        if self.text_model.with_logging {
38            initialize_logging();
39        }
40
41        let loader = NormalLoaderBuilder::new(
42            config,
43            self.text_model.chat_template,
44            self.text_model.tokenizer_json,
45            Some(self.text_model.model_id),
46            self.text_model.no_kv_cache,
47            self.text_model.jinja_explicit,
48        )
49        .with_lora(self.lora_adapter_ids)
50        .build(self.text_model.loader_type)?;
51
52        // Load, into a Pipeline
53        let pipeline = loader.load_model_from_hf(
54            self.text_model.hf_revision,
55            self.text_model.token_source,
56            &self.text_model.dtype,
57            &best_device(self.text_model.force_cpu)?,
58            !self.text_model.with_logging,
59            self.text_model
60                .device_mapping
61                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
62            self.text_model.isq,
63            self.text_model.paged_attn_cfg,
64        )?;
65
66        let scheduler_method = match self.text_model.paged_attn_cfg {
67            Some(_) => {
68                let config = pipeline
69                    .lock()
70                    .await
71                    .get_metadata()
72                    .cache_config
73                    .as_ref()
74                    .unwrap()
75                    .clone();
76
77                SchedulerConfig::PagedAttentionMeta {
78                    max_num_seqs: self.text_model.max_num_seqs,
79                    config,
80                }
81            }
82            None => SchedulerConfig::DefaultScheduler {
83                method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
84            },
85        };
86
87        let mut runner = MistralRsBuilder::new(
88            pipeline,
89            scheduler_method,
90            self.text_model.throughput_logging,
91            self.text_model.search_bert_model,
92        )
93        .with_no_kv_cache(self.text_model.no_kv_cache)
94        .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());
95
96        if let Some(n) = self.text_model.prefix_cache_n {
97            runner = runner.with_prefix_cache_n(n)
98        }
99
100        Ok(Model::new(runner.build()))
101    }
102}