1use mistralrs_core::*;
2
3use crate::{best_device, Model, TextModelBuilder};
4
5pub struct LoraModelBuilder {
7 text_model: TextModelBuilder,
8 lora_adapter_ids: Vec<String>,
9}
10
11impl LoraModelBuilder {
12 pub fn from_text_model_builder(
13 text_model: TextModelBuilder,
14 lora_adapter_ids: impl IntoIterator<Item = impl ToString>,
15 ) -> Self {
16 Self {
17 text_model,
18 lora_adapter_ids: lora_adapter_ids
19 .into_iter()
20 .map(|x| x.to_string())
21 .collect(),
22 }
23 }
24
25 pub async fn build(self) -> anyhow::Result<Model> {
26 let config = NormalSpecificConfig {
27 use_flash_attn: self.text_model.use_flash_attn,
28 prompt_chunksize: self.text_model.prompt_chunksize,
29 topology: self.text_model.topology,
30 organization: self.text_model.organization,
31 write_uqff: self.text_model.write_uqff,
32 from_uqff: self.text_model.from_uqff,
33 imatrix: None,
34 calibration_file: None,
35 hf_cache_path: self.text_model.hf_cache_path,
36 };
37
38 if self.text_model.with_logging {
39 initialize_logging();
40 }
41
42 let loader = NormalLoaderBuilder::new(
43 config,
44 self.text_model.chat_template,
45 self.text_model.tokenizer_json,
46 Some(self.text_model.model_id),
47 self.text_model.no_kv_cache,
48 self.text_model.jinja_explicit,
49 )
50 .with_lora(self.lora_adapter_ids)
51 .build(self.text_model.loader_type)?;
52
53 let pipeline = loader.load_model_from_hf(
55 self.text_model.hf_revision,
56 self.text_model.token_source,
57 &self.text_model.dtype,
58 &best_device(self.text_model.force_cpu)?,
59 !self.text_model.with_logging,
60 self.text_model
61 .device_mapping
62 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
63 self.text_model.isq,
64 self.text_model.paged_attn_cfg,
65 )?;
66
67 let scheduler_method = match self.text_model.paged_attn_cfg {
68 Some(_) => {
69 let config = pipeline
70 .lock()
71 .await
72 .get_metadata()
73 .cache_config
74 .as_ref()
75 .unwrap()
76 .clone();
77
78 SchedulerConfig::PagedAttentionMeta {
79 max_num_seqs: self.text_model.max_num_seqs,
80 config,
81 }
82 }
83 None => SchedulerConfig::DefaultScheduler {
84 method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
85 },
86 };
87
88 let mut runner = MistralRsBuilder::new(
89 pipeline,
90 scheduler_method,
91 self.text_model.throughput_logging,
92 self.text_model.search_bert_model,
93 )
94 .with_no_kv_cache(self.text_model.no_kv_cache)
95 .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());
96
97 if let Some(n) = self.text_model.prefix_cache_n {
98 runner = runner.with_prefix_cache_n(n)
99 }
100
101 Ok(Model::new(runner.build()))
102 }
103}