mistralrs/lora_model.rs

use mistralrs_core::*;

use crate::{best_device, Model, TextModelBuilder};

/// Wrapper of [`TextModelBuilder`] for LoRA models.
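///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest): it assumes `TextModelBuilder::new`
/// takes a model ID, and the model and adapter IDs below are placeholders standing in for
/// a real base model and a LoRA adapter trained for it.
///
/// ```ignore
/// use mistralrs::{LoraModelBuilder, TextModelBuilder};
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     let text_model = TextModelBuilder::new("meta-llama/Llama-3.2-3B-Instruct");
///     let model = LoraModelBuilder::from_text_model_builder(
///         text_model,
///         ["my-org/my-lora-adapter"],
///     )
///     .build()
///     .await?;
///     // `model` can now serve requests like any other built model.
///     Ok(())
/// }
/// ```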
pub struct LoraModelBuilder {
    text_model: TextModelBuilder,
    lora_adapter_ids: Vec<String>,
}

impl LoraModelBuilder {
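    /// Create a LoRA model builder from an existing [`TextModelBuilder`], attaching the
    /// given LoRA adapter IDs to be loaded alongside the base model.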
    pub fn from_text_model_builder(
        text_model: TextModelBuilder,
        lora_adapter_ids: impl IntoIterator<Item = impl ToString>,
    ) -> Self {
        Self {
            text_model,
            lora_adapter_ids: lora_adapter_ids
                .into_iter()
                .map(|x| x.to_string())
                .collect(),
        }
    }

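    /// Consume the builder, load the base model together with the configured LoRA
    /// adapters, and return a ready-to-use [`Model`].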
    pub async fn build(self) -> anyhow::Result<Model> {
        // Model-loading configuration, carried over from the wrapped `TextModelBuilder`.
        let config = NormalSpecificConfig {
            use_flash_attn: self.text_model.use_flash_attn,
            prompt_chunksize: self.text_model.prompt_chunksize,
            topology: self.text_model.topology,
            organization: self.text_model.organization,
            write_uqff: self.text_model.write_uqff,
            from_uqff: self.text_model.from_uqff,
            imatrix: None,
            calibration_file: None,
            hf_cache_path: self.text_model.hf_cache_path,
        };

        if self.text_model.with_logging {
            initialize_logging();
        }

        // Build a normal (text) model loader with the LoRA adapters attached.
        let loader = NormalLoaderBuilder::new(
            config,
            self.text_model.chat_template,
            self.text_model.tokenizer_json,
            Some(self.text_model.model_id),
            self.text_model.no_kv_cache,
            self.text_model.jinja_explicit,
        )
        .with_lora(self.lora_adapter_ids)
        .build(self.text_model.loader_type)?;

        // Load into a Pipeline
        let pipeline = loader.load_model_from_hf(
            self.text_model.hf_revision,
            self.text_model.token_source,
            &self.text_model.dtype,
            &best_device(self.text_model.force_cpu)?,
            !self.text_model.with_logging,
            self.text_model
                .device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            self.text_model.isq,
            self.text_model.paged_attn_cfg,
        )?;

        // Select the scheduler: PagedAttention metadata when configured, otherwise the
        // default fixed-size scheduler.
        let scheduler_method = match self.text_model.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .unwrap()
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.text_model.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
            },
        };

        // Assemble the runner and wrap it in a `Model`.
        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.text_model.throughput_logging,
            self.text_model.search_bert_model,
        )
        .with_no_kv_cache(self.text_model.no_kv_cache)
        .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());

        if let Some(n) = self.text_model.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n);
        }

        Ok(Model::new(runner.build()))
    }
}