1use mistralrs_core::*;
2
3use crate::{best_device, Model, TextModelBuilder};
4
5pub struct XLoraModelBuilder {
7 text_model: TextModelBuilder,
8 xlora_model_id: String,
9 ordering: Ordering,
10 tgt_non_granular_index: Option<usize>,
11}
12
13impl XLoraModelBuilder {
14 pub fn from_text_model_builder(
15 text_model: TextModelBuilder,
16 xlora_model_id: impl ToString,
17 ordering: Ordering,
18 ) -> Self {
19 Self {
20 text_model,
21 xlora_model_id: xlora_model_id.to_string(),
22 ordering,
23 tgt_non_granular_index: None,
24 }
25 }
26
27 pub fn tgt_non_granular_index(mut self, tgt_non_granular_idx: usize) -> Self {
28 self.tgt_non_granular_index = Some(tgt_non_granular_idx);
29 self
30 }
31
32 pub async fn build(self) -> anyhow::Result<Model> {
33 let config = NormalSpecificConfig {
34 prompt_chunksize: self.text_model.prompt_chunksize,
35 topology: self.text_model.topology,
36 organization: self.text_model.organization,
37 write_uqff: self.text_model.write_uqff,
38 from_uqff: self.text_model.from_uqff,
39 imatrix: None,
40 calibration_file: None,
41 hf_cache_path: self.text_model.hf_cache_path,
42 matformer_config_path: None,
43 matformer_slice_name: None,
44 };
45
46 if self.text_model.with_logging {
47 initialize_logging();
48 }
49
50 let loader = NormalLoaderBuilder::new(
51 config,
52 self.text_model.chat_template,
53 self.text_model.tokenizer_json,
54 Some(self.text_model.model_id),
55 self.text_model.no_kv_cache,
56 self.text_model.jinja_explicit,
57 )
58 .with_xlora(
59 self.xlora_model_id,
60 self.ordering,
61 self.text_model.no_kv_cache,
62 self.tgt_non_granular_index,
63 )
64 .build(self.text_model.loader_type)?;
65
66 let pipeline = loader.load_model_from_hf(
68 self.text_model.hf_revision,
69 self.text_model.token_source,
70 &self.text_model.dtype,
71 &best_device(self.text_model.force_cpu)?,
72 !self.text_model.with_logging,
73 self.text_model
74 .device_mapping
75 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
76 self.text_model.isq,
77 self.text_model.paged_attn_cfg,
78 )?;
79
80 let scheduler_method = match self.text_model.paged_attn_cfg {
81 Some(_) => {
82 let config = pipeline
83 .lock()
84 .await
85 .get_metadata()
86 .cache_config
87 .as_ref()
88 .unwrap()
89 .clone();
90
91 SchedulerConfig::PagedAttentionMeta {
92 max_num_seqs: self.text_model.max_num_seqs,
93 config,
94 }
95 }
96 None => SchedulerConfig::DefaultScheduler {
97 method: DefaultSchedulerMethod::Fixed(self.text_model.max_num_seqs.try_into()?),
98 },
99 };
100
101 let mut runner = MistralRsBuilder::new(
102 pipeline,
103 scheduler_method,
104 self.text_model.throughput_logging,
105 self.text_model.search_bert_model,
106 );
107 if let Some(cb) = self.text_model.search_callback.clone() {
108 runner = runner.with_search_callback(cb);
109 }
110 for (name, cb) in &self.text_model.tool_callbacks {
111 runner = runner.with_tool_callback(name.clone(), cb.clone());
112 }
113 runner = runner
114 .with_no_kv_cache(self.text_model.no_kv_cache)
115 .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());
116
117 if let Some(n) = self.text_model.prefix_cache_n {
118 runner = runner.with_prefix_cache_n(n)
119 }
120
121 Ok(Model::new(runner.build().await))
122 }
123}