1use mistralrs_core::*;
2
3use crate::{best_device, Model, TextModelBuilder};
4
/// Builder for a text model that is loaded together with LoRA adapters.
///
/// Pairs a [`TextModelBuilder`] with the adapter IDs that `build` hands to the
/// loader via `with_lora`.
pub struct LoraModelBuilder {
    // Base text-model settings; `build` reads every loader/runner option from here.
    text_model: TextModelBuilder,
    // Adapter identifiers, stored as owned strings.
    lora_adapter_ids: Vec<String>,
}
10
11impl LoraModelBuilder {
12 pub fn from_text_model_builder(
13 text_model: TextModelBuilder,
14 lora_adapter_ids: impl IntoIterator<Item = impl ToString>,
15 ) -> Self {
16 Self {
17 text_model,
18 lora_adapter_ids: lora_adapter_ids
19 .into_iter()
20 .map(|x| x.to_string())
21 .collect(),
22 }
23 }
24
25 pub async fn build(self) -> anyhow::Result<Model> {
26 let config = NormalSpecificConfig {
27 prompt_chunksize: self.text_model.prompt_chunksize,
28 topology: self.text_model.topology,
29 organization: self.text_model.organization,
30 write_uqff: self.text_model.write_uqff,
31 from_uqff: self.text_model.from_uqff,
32 imatrix: None,
33 calibration_file: None,
34 hf_cache_path: self.text_model.hf_cache_path,
35 };
36
37 if self.text_model.with_logging {
38 initialize_logging();
39 }
40
41 let loader = NormalLoaderBuilder::new(
42 config,
43 self.text_model.chat_template,
44 self.text_model.tokenizer_json,
45 Some(self.text_model.model_id),
46 self.text_model.no_kv_cache,
47 self.text_model.jinja_explicit,
48 )
49 .with_lora(self.lora_adapter_ids)
50 .build(self.text_model.loader_type)?;
51
52 let pipeline = loader.load_model_from_hf(
54 self.text_model.hf_revision,
55 self.text_model.token_source,
56 &self.text_model.dtype,
57 &best_device(self.text_model.force_cpu)?,
58 !self.text_model.with_logging,
59 self.text_model
60 .device_mapping
61 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
62 self.text_model.isq,
63 self.text_model.paged_attn_cfg,
64 )?;
65
66 let scheduler_method = match self.text_model.paged_attn_cfg {
67 Some(_) => {
68 let config = pipeline
69 .lock()
70 .await
71 .get_metadata()
72 .cache_config
73 .as_ref()
74 .unwrap()
75 .clone();
76
77 SchedulerConfig::PagedAttentionMeta {
78 max_num_seqs: self.text_model.max_num_seqs,
79 config,
80 }
81 }
82 None => SchedulerConfig::DefaultScheduler {
83 method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
84 },
85 };
86
87 let mut runner = MistralRsBuilder::new(
88 pipeline,
89 scheduler_method,
90 self.text_model.throughput_logging,
91 self.text_model.search_bert_model,
92 )
93 .with_no_kv_cache(self.text_model.no_kv_cache)
94 .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());
95
96 if let Some(n) = self.text_model.prefix_cache_n {
97 runner = runner.with_prefix_cache_n(n)
98 }
99
100 Ok(Model::new(runner.build()))
101 }
102}