mistralrs/
gguf_xlora_model.rs

1use mistralrs_core::*;
2
3use crate::{best_device, GgufModelBuilder, Model};
4
5/// Wrapper of [`GgufModelBuilder`] for X-LoRA models.
6pub struct GgufXLoraModelBuilder {
7    gguf_model: GgufModelBuilder,
8    xlora_model_id: String,
9    ordering: Ordering,
10    tgt_non_granular_index: Option<usize>,
11}
12
13impl GgufXLoraModelBuilder {
14    pub fn from_gguf_model_builder(
15        gguf_model: GgufModelBuilder,
16        xlora_model_id: impl ToString,
17        ordering: Ordering,
18    ) -> Self {
19        Self {
20            gguf_model,
21            xlora_model_id: xlora_model_id.to_string(),
22            ordering,
23            tgt_non_granular_index: None,
24        }
25    }
26
27    pub fn tgt_non_granular_index(mut self, tgt_non_granular_idx: usize) -> Self {
28        self.tgt_non_granular_index = Some(tgt_non_granular_idx);
29        self
30    }
31
32    pub async fn build(self) -> anyhow::Result<Model> {
33        let config = GGUFSpecificConfig {
34            topology: self.gguf_model.topology,
35        };
36
37        if self.gguf_model.with_logging {
38            initialize_logging();
39        }
40
41        let loader = GGUFLoaderBuilder::new(
42            self.gguf_model.chat_template,
43            self.gguf_model.tok_model_id,
44            self.gguf_model.model_id,
45            self.gguf_model.files,
46            config,
47            self.gguf_model.no_kv_cache,
48            self.gguf_model.jinja_explicit,
49        )
50        .with_xlora(
51            self.xlora_model_id,
52            self.ordering,
53            self.gguf_model.no_kv_cache,
54            self.tgt_non_granular_index,
55        )
56        .build();
57
58        // Load, into a Pipeline
59        let pipeline = loader.load_model_from_hf(
60            self.gguf_model.hf_revision,
61            self.gguf_model.token_source,
62            &ModelDType::Auto,
63            &best_device(self.gguf_model.force_cpu)?,
64            !self.gguf_model.with_logging,
65            self.gguf_model
66                .device_mapping
67                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
68            None,
69            self.gguf_model.paged_attn_cfg,
70        )?;
71
72        let scheduler_method = match self.gguf_model.paged_attn_cfg {
73            Some(_) => {
74                let config = pipeline
75                    .lock()
76                    .await
77                    .get_metadata()
78                    .cache_config
79                    .as_ref()
80                    .unwrap()
81                    .clone();
82
83                SchedulerConfig::PagedAttentionMeta {
84                    max_num_seqs: self.gguf_model.max_num_seqs,
85                    config,
86                }
87            }
88            None => SchedulerConfig::DefaultScheduler {
89                method: DefaultSchedulerMethod::Fixed(self.gguf_model.max_num_seqs.try_into()?),
90            },
91        };
92
93        let mut runner = MistralRsBuilder::new(
94            pipeline,
95            scheduler_method,
96            self.gguf_model.throughput_logging,
97            self.gguf_model.search_bert_model,
98        );
99        if let Some(cb) = self.gguf_model.search_callback.clone() {
100            runner = runner.with_search_callback(cb);
101        }
102        for (name, cb) in &self.gguf_model.tool_callbacks {
103            runner = runner.with_tool_callback(name.clone(), cb.clone());
104        }
105        runner = runner
106            .with_no_kv_cache(self.gguf_model.no_kv_cache)
107            .with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none());
108
109        if let Some(n) = self.gguf_model.prefix_cache_n {
110            runner = runner.with_prefix_cache_n(n)
111        }
112
113        Ok(Model::new(runner.build().await))
114    }
115}