mistralrs/
gguf_lora_model.rs

1use mistralrs_core::*;
2
3use crate::{best_device, GgufModelBuilder, Model};
4
5/// Wrapper of [`GgufModelBuilder`] for LoRA models.
6pub struct GgufLoraModelBuilder {
7    gguf_model: GgufModelBuilder,
8    lora_model_id: String,
9    ordering: Ordering,
10}
11
12impl GgufLoraModelBuilder {
13    pub fn from_gguf_model_builder(
14        gguf_model: GgufModelBuilder,
15        lora_model_id: impl ToString,
16        ordering: Ordering,
17    ) -> Self {
18        Self {
19            gguf_model,
20            lora_model_id: lora_model_id.to_string(),
21            ordering,
22        }
23    }
24
25    pub async fn build(self) -> anyhow::Result<Model> {
26        let config = GGUFSpecificConfig {
27            topology: self.gguf_model.topology,
28        };
29
30        if self.gguf_model.with_logging {
31            initialize_logging();
32        }
33
34        let loader = GGUFLoaderBuilder::new(
35            self.gguf_model.chat_template,
36            self.gguf_model.tok_model_id,
37            self.gguf_model.model_id,
38            self.gguf_model.files,
39            config,
40            self.gguf_model.no_kv_cache,
41            self.gguf_model.jinja_explicit,
42        )
43        .with_lora(self.lora_model_id, self.ordering)
44        .build();
45
46        // Load, into a Pipeline
47        let pipeline = loader.load_model_from_hf(
48            self.gguf_model.hf_revision,
49            self.gguf_model.token_source,
50            &ModelDType::Auto,
51            &best_device(self.gguf_model.force_cpu)?,
52            !self.gguf_model.with_logging,
53            self.gguf_model
54                .device_mapping
55                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
56            None,
57            self.gguf_model.paged_attn_cfg,
58        )?;
59
60        let scheduler_method = match self.gguf_model.paged_attn_cfg {
61            Some(_) => {
62                let config = pipeline
63                    .lock()
64                    .await
65                    .get_metadata()
66                    .cache_config
67                    .as_ref()
68                    .unwrap()
69                    .clone();
70
71                SchedulerConfig::PagedAttentionMeta {
72                    max_num_seqs: self.gguf_model.max_num_seqs,
73                    config,
74                }
75            }
76            None => SchedulerConfig::DefaultScheduler {
77                method: DefaultSchedulerMethod::Fixed(self.gguf_model.max_num_seqs.try_into()?),
78            },
79        };
80
81        let mut runner = MistralRsBuilder::new(
82            pipeline,
83            scheduler_method,
84            self.gguf_model.throughput_logging,
85            self.gguf_model.search_bert_model,
86        );
87        if let Some(cb) = self.gguf_model.search_callback.clone() {
88            runner = runner.with_search_callback(cb);
89        }
90        for (name, cb) in &self.gguf_model.tool_callbacks {
91            runner = runner.with_tool_callback(name.clone(), cb.clone());
92        }
93        runner = runner
94            .with_no_kv_cache(self.gguf_model.no_kv_cache)
95            .with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none());
96
97        if let Some(n) = self.gguf_model.prefix_cache_n {
98            runner = runner.with_prefix_cache_n(n)
99        }
100
101        Ok(Model::new(runner.build().await))
102    }
103}