mistralrs/
gguf_xlora_model.rs

1use mistralrs_core::*;
2
3use crate::{best_device, GgufModelBuilder, Model};
4
5/// Wrapper of [`GgufModelBuilder`] for X-LoRA models.
6pub struct GgufXLoraModelBuilder {
7    gguf_model: GgufModelBuilder,
8    xlora_model_id: String,
9    ordering: Ordering,
10    tgt_non_granular_index: Option<usize>,
11}
12
13impl GgufXLoraModelBuilder {
14    pub fn from_gguf_model_builder(
15        gguf_model: GgufModelBuilder,
16        xlora_model_id: impl ToString,
17        ordering: Ordering,
18    ) -> Self {
19        Self {
20            gguf_model,
21            xlora_model_id: xlora_model_id.to_string(),
22            ordering,
23            tgt_non_granular_index: None,
24        }
25    }
26
27    pub fn tgt_non_granular_index(mut self, tgt_non_granular_idx: usize) -> Self {
28        self.tgt_non_granular_index = Some(tgt_non_granular_idx);
29        self
30    }
31
32    pub async fn build(self) -> anyhow::Result<Model> {
33        let config = GGUFSpecificConfig {
34            prompt_chunksize: self.gguf_model.prompt_chunksize,
35            topology: self.gguf_model.topology,
36        };
37
38        if self.gguf_model.with_logging {
39            initialize_logging();
40        }
41
42        let loader = GGUFLoaderBuilder::new(
43            self.gguf_model.chat_template,
44            self.gguf_model.tok_model_id,
45            self.gguf_model.model_id,
46            self.gguf_model.files,
47            config,
48            self.gguf_model.no_kv_cache,
49            self.gguf_model.jinja_explicit,
50        )
51        .with_xlora(
52            self.xlora_model_id,
53            self.ordering,
54            self.gguf_model.no_kv_cache,
55            self.tgt_non_granular_index,
56        )
57        .build();
58
59        // Load, into a Pipeline
60        let pipeline = loader.load_model_from_hf(
61            self.gguf_model.hf_revision,
62            self.gguf_model.token_source,
63            &ModelDType::Auto,
64            &best_device(self.gguf_model.force_cpu)?,
65            !self.gguf_model.with_logging,
66            self.gguf_model
67                .device_mapping
68                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
69            None,
70            self.gguf_model.paged_attn_cfg,
71        )?;
72
73        let scheduler_method = match self.gguf_model.paged_attn_cfg {
74            Some(_) => {
75                let config = pipeline
76                    .lock()
77                    .await
78                    .get_metadata()
79                    .cache_config
80                    .as_ref()
81                    .unwrap()
82                    .clone();
83
84                SchedulerConfig::PagedAttentionMeta {
85                    max_num_seqs: self.gguf_model.max_num_seqs,
86                    config,
87                }
88            }
89            None => SchedulerConfig::DefaultScheduler {
90                method: DefaultSchedulerMethod::Fixed(self.gguf_model.max_num_seqs.try_into()?),
91            },
92        };
93
94        let mut runner = MistralRsBuilder::new(
95            pipeline,
96            scheduler_method,
97            self.gguf_model.throughput_logging,
98            self.gguf_model.search_bert_model,
99        )
100        .with_no_kv_cache(self.gguf_model.no_kv_cache)
101        .with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none());
102
103        if let Some(n) = self.gguf_model.prefix_cache_n {
104            runner = runner.with_prefix_cache_n(n)
105        }
106
107        Ok(Model::new(runner.build()))
108    }
109}