// mistralrs/gguf_lora_model.rs

use mistralrs_core::*;

use crate::{best_device, GgufModelBuilder, Model};

5/// Wrapper of [`GgufModelBuilder`] for LoRA models.
6pub struct GgufLoraModelBuilder {
7    gguf_model: GgufModelBuilder,
8    lora_model_id: String,
9    ordering: Ordering,
10}
11
12impl GgufLoraModelBuilder {
13    pub fn from_gguf_model_builder(
14        gguf_model: GgufModelBuilder,
15        lora_model_id: impl ToString,
16        ordering: Ordering,
17    ) -> Self {
18        Self {
19            gguf_model,
20            lora_model_id: lora_model_id.to_string(),
21            ordering,
22        }
23    }
24
25    pub async fn build(self) -> anyhow::Result<Model> {
26        let config = GGUFSpecificConfig {
27            prompt_chunksize: self.gguf_model.prompt_chunksize,
28            topology: self.gguf_model.topology,
29        };
30
31        if self.gguf_model.with_logging {
32            initialize_logging();
33        }
34
35        let loader = GGUFLoaderBuilder::new(
36            self.gguf_model.chat_template,
37            self.gguf_model.tok_model_id,
38            self.gguf_model.model_id,
39            self.gguf_model.files,
40            config,
41            self.gguf_model.no_kv_cache,
42            self.gguf_model.jinja_explicit,
43        )
44        .with_lora(self.lora_model_id, self.ordering)
45        .build();
46
47        // Load, into a Pipeline
48        let pipeline = loader.load_model_from_hf(
49            self.gguf_model.hf_revision,
50            self.gguf_model.token_source,
51            &ModelDType::Auto,
52            &best_device(self.gguf_model.force_cpu)?,
53            !self.gguf_model.with_logging,
54            self.gguf_model
55                .device_mapping
56                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
57            None,
58            self.gguf_model.paged_attn_cfg,
59        )?;
60
61        let scheduler_method = match self.gguf_model.paged_attn_cfg {
62            Some(_) => {
63                let config = pipeline
64                    .lock()
65                    .await
66                    .get_metadata()
67                    .cache_config
68                    .as_ref()
69                    .unwrap()
70                    .clone();
71
72                SchedulerConfig::PagedAttentionMeta {
73                    max_num_seqs: self.gguf_model.max_num_seqs,
74                    config,
75                }
76            }
77            None => SchedulerConfig::DefaultScheduler {
78                method: DefaultSchedulerMethod::Fixed(self.gguf_model.max_num_seqs.try_into()?),
79            },
80        };
81
82        let mut runner = MistralRsBuilder::new(
83            pipeline,
84            scheduler_method,
85            self.gguf_model.throughput_logging,
86            self.gguf_model.search_bert_model,
87        )
88        .with_no_kv_cache(self.gguf_model.no_kv_cache)
89        .with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none());
90
91        if let Some(n) = self.gguf_model.prefix_cache_n {
92            runner = runner.with_prefix_cache_n(n)
93        }
94
95        Ok(Model::new(runner.build()))
96    }
97}