1use mistralrs_core::*;
2
3use crate::{best_device, Model, TextModelBuilder};
4
/// Builder that loads a text model with one or more LoRA adapters applied.
///
/// It wraps a [`TextModelBuilder`], whose settings are reused by
/// [`LoraModelBuilder::build`], and adds the list of adapter identifiers that
/// are handed to the loader through `NormalLoaderBuilder::with_lora`.
pub struct LoraModelBuilder {
    /// Model, device, scheduler and runner settings reused when building.
    text_model: TextModelBuilder,
    /// LoRA adapter identifiers, forwarded unchanged to `NormalLoaderBuilder::with_lora`.
    /// NOTE(review): presumably Hugging Face repo IDs or local paths — confirm against the loader.
    lora_adapter_ids: Vec<String>,
}
10
11impl LoraModelBuilder {
12 pub fn from_text_model_builder(
13 text_model: TextModelBuilder,
14 lora_adapter_ids: impl IntoIterator<Item = impl ToString>,
15 ) -> Self {
16 Self {
17 text_model,
18 lora_adapter_ids: lora_adapter_ids
19 .into_iter()
20 .map(|x| x.to_string())
21 .collect(),
22 }
23 }
24
25 pub async fn build(self) -> anyhow::Result<Model> {
26 let config = NormalSpecificConfig {
27 prompt_chunksize: self.text_model.prompt_chunksize,
28 topology: self.text_model.topology,
29 organization: self.text_model.organization,
30 write_uqff: self.text_model.write_uqff,
31 from_uqff: self.text_model.from_uqff,
32 imatrix: None,
33 calibration_file: None,
34 hf_cache_path: self.text_model.hf_cache_path,
35 matformer_config_path: None,
36 matformer_slice_name: None,
37 };
38
39 if self.text_model.with_logging {
40 initialize_logging();
41 }
42
43 let loader = NormalLoaderBuilder::new(
44 config,
45 self.text_model.chat_template,
46 self.text_model.tokenizer_json,
47 Some(self.text_model.model_id),
48 self.text_model.no_kv_cache,
49 self.text_model.jinja_explicit,
50 )
51 .with_lora(self.lora_adapter_ids)
52 .build(self.text_model.loader_type)?;
53
54 let pipeline = loader.load_model_from_hf(
56 self.text_model.hf_revision,
57 self.text_model.token_source,
58 &self.text_model.dtype,
59 &best_device(self.text_model.force_cpu)?,
60 !self.text_model.with_logging,
61 self.text_model
62 .device_mapping
63 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
64 self.text_model.isq,
65 self.text_model.paged_attn_cfg,
66 )?;
67
68 let scheduler_method = match self.text_model.paged_attn_cfg {
69 Some(_) => {
70 let config = pipeline
71 .lock()
72 .await
73 .get_metadata()
74 .cache_config
75 .as_ref()
76 .unwrap()
77 .clone();
78
79 SchedulerConfig::PagedAttentionMeta {
80 max_num_seqs: self.text_model.max_num_seqs,
81 config,
82 }
83 }
84 None => SchedulerConfig::DefaultScheduler {
85 method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
86 },
87 };
88
89 let mut runner = MistralRsBuilder::new(
90 pipeline,
91 scheduler_method,
92 self.text_model.throughput_logging,
93 self.text_model.search_bert_model,
94 );
95 if let Some(cb) = self.text_model.search_callback.clone() {
96 runner = runner.with_search_callback(cb);
97 }
98 for (name, cb) in &self.text_model.tool_callbacks {
99 runner = runner.with_tool_callback(name.clone(), cb.clone());
100 }
101 runner = runner
102 .with_no_kv_cache(self.text_model.no_kv_cache)
103 .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());
104
105 if let Some(n) = self.text_model.prefix_cache_n {
106 runner = runner.with_prefix_cache_n(n)
107 }
108
109 Ok(Model::new(runner.build().await))
110 }
111}