// mistralrs/lora_model.rs

use mistralrs_core::*;

use crate::{best_device, Model, TextModelBuilder};

5/// Wrapper of [`TextModelBuilder`] for LoRA models.
6pub struct LoraModelBuilder {
7    text_model: TextModelBuilder,
8    lora_adapter_ids: Vec<String>,
9}
10
11impl LoraModelBuilder {
12    pub fn from_text_model_builder(
13        text_model: TextModelBuilder,
14        lora_adapter_ids: impl IntoIterator<Item = impl ToString>,
15    ) -> Self {
16        Self {
17            text_model,
18            lora_adapter_ids: lora_adapter_ids
19                .into_iter()
20                .map(|x| x.to_string())
21                .collect(),
22        }
23    }
24
25    pub async fn build(self) -> anyhow::Result<Model> {
26        let config = NormalSpecificConfig {
27            topology: self.text_model.topology,
28            organization: self.text_model.organization,
29            write_uqff: self.text_model.write_uqff,
30            from_uqff: self.text_model.from_uqff,
31            imatrix: None,
32            calibration_file: None,
33            hf_cache_path: self.text_model.hf_cache_path,
34            matformer_config_path: None,
35            matformer_slice_name: None,
36        };
37
38        if self.text_model.with_logging {
39            initialize_logging();
40        }
41
42        let loader = NormalLoaderBuilder::new(
43            config,
44            self.text_model.chat_template,
45            self.text_model.tokenizer_json,
46            Some(self.text_model.model_id),
47            self.text_model.no_kv_cache,
48            self.text_model.jinja_explicit,
49        )
50        .with_lora(self.lora_adapter_ids)
51        .build(self.text_model.loader_type)?;
52
53        // Load, into a Pipeline
54        let pipeline = loader.load_model_from_hf(
55            self.text_model.hf_revision,
56            self.text_model.token_source,
57            &self.text_model.dtype,
58            &best_device(self.text_model.force_cpu)?,
59            !self.text_model.with_logging,
60            self.text_model
61                .device_mapping
62                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
63            self.text_model.isq,
64            self.text_model.paged_attn_cfg,
65        )?;
66
67        let scheduler_method = match self.text_model.paged_attn_cfg {
68            Some(_) => {
69                let config = pipeline
70                    .lock()
71                    .await
72                    .get_metadata()
73                    .cache_config
74                    .as_ref()
75                    .unwrap()
76                    .clone();
77
78                SchedulerConfig::PagedAttentionMeta {
79                    max_num_seqs: self.text_model.max_num_seqs,
80                    config,
81                }
82            }
83            None => SchedulerConfig::DefaultScheduler {
84                method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
85            },
86        };
87
88        let mut runner = MistralRsBuilder::new(
89            pipeline,
90            scheduler_method,
91            self.text_model.throughput_logging,
92            self.text_model.search_bert_model,
93        );
94        if let Some(cb) = self.text_model.search_callback.clone() {
95            runner = runner.with_search_callback(cb);
96        }
97        for (name, cb) in &self.text_model.tool_callbacks {
98            runner = runner.with_tool_callback(name.clone(), cb.clone());
99        }
100        runner = runner
101            .with_no_kv_cache(self.text_model.no_kv_cache)
102            .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());
103
104        if let Some(n) = self.text_model.prefix_cache_n {
105            runner = runner.with_prefix_cache_n(n)
106        }
107
108        Ok(Model::new(runner.build().await))
109    }
110}