mistralrs/
gguf_xlora_model.rs1use mistralrs_core::*;
2
3use crate::{best_device, GgufModelBuilder, Model};
4
5pub struct GgufXLoraModelBuilder {
7 gguf_model: GgufModelBuilder,
8 xlora_model_id: String,
9 ordering: Ordering,
10 tgt_non_granular_index: Option<usize>,
11}
12
13impl GgufXLoraModelBuilder {
14 pub fn from_gguf_model_builder(
15 gguf_model: GgufModelBuilder,
16 xlora_model_id: impl ToString,
17 ordering: Ordering,
18 ) -> Self {
19 Self {
20 gguf_model,
21 xlora_model_id: xlora_model_id.to_string(),
22 ordering,
23 tgt_non_granular_index: None,
24 }
25 }
26
27 pub fn tgt_non_granular_index(mut self, tgt_non_granular_idx: usize) -> Self {
28 self.tgt_non_granular_index = Some(tgt_non_granular_idx);
29 self
30 }
31
32 pub async fn build(self) -> anyhow::Result<Model> {
33 let config = GGUFSpecificConfig {
34 topology: self.gguf_model.topology,
35 };
36
37 if self.gguf_model.with_logging {
38 initialize_logging();
39 }
40
41 let loader = GGUFLoaderBuilder::new(
42 self.gguf_model.chat_template,
43 self.gguf_model.tok_model_id,
44 self.gguf_model.model_id,
45 self.gguf_model.files,
46 config,
47 self.gguf_model.no_kv_cache,
48 self.gguf_model.jinja_explicit,
49 )
50 .with_xlora(
51 self.xlora_model_id,
52 self.ordering,
53 self.gguf_model.no_kv_cache,
54 self.tgt_non_granular_index,
55 )
56 .build();
57
58 let pipeline = loader.load_model_from_hf(
60 self.gguf_model.hf_revision,
61 self.gguf_model.token_source,
62 &ModelDType::Auto,
63 &best_device(self.gguf_model.force_cpu)?,
64 !self.gguf_model.with_logging,
65 self.gguf_model
66 .device_mapping
67 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
68 None,
69 self.gguf_model.paged_attn_cfg,
70 )?;
71
72 let scheduler_method = match self.gguf_model.paged_attn_cfg {
73 Some(_) => {
74 let config = pipeline
75 .lock()
76 .await
77 .get_metadata()
78 .cache_config
79 .as_ref()
80 .unwrap()
81 .clone();
82
83 SchedulerConfig::PagedAttentionMeta {
84 max_num_seqs: self.gguf_model.max_num_seqs,
85 config,
86 }
87 }
88 None => SchedulerConfig::DefaultScheduler {
89 method: DefaultSchedulerMethod::Fixed(self.gguf_model.max_num_seqs.try_into()?),
90 },
91 };
92
93 let mut runner = MistralRsBuilder::new(
94 pipeline,
95 scheduler_method,
96 self.gguf_model.throughput_logging,
97 self.gguf_model.search_bert_model,
98 );
99 if let Some(cb) = self.gguf_model.search_callback.clone() {
100 runner = runner.with_search_callback(cb);
101 }
102 for (name, cb) in &self.gguf_model.tool_callbacks {
103 runner = runner.with_tool_callback(name.clone(), cb.clone());
104 }
105 runner = runner
106 .with_no_kv_cache(self.gguf_model.no_kv_cache)
107 .with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none());
108
109 if let Some(n) = self.gguf_model.prefix_cache_n {
110 runner = runner.with_prefix_cache_n(n)
111 }
112
113 Ok(Model::new(runner.build().await))
114 }
115}