mistralrs/
gguf_lora_model.rs1use mistralrs_core::*;
2
3use crate::{best_device, GgufModelBuilder, Model};
4
5pub struct GgufLoraModelBuilder {
7 gguf_model: GgufModelBuilder,
8 lora_model_id: String,
9 ordering: Ordering,
10}
11
12impl GgufLoraModelBuilder {
13 pub fn from_gguf_model_builder(
14 gguf_model: GgufModelBuilder,
15 lora_model_id: impl ToString,
16 ordering: Ordering,
17 ) -> Self {
18 Self {
19 gguf_model,
20 lora_model_id: lora_model_id.to_string(),
21 ordering,
22 }
23 }
24
25 pub async fn build(self) -> anyhow::Result<Model> {
26 let config = GGUFSpecificConfig {
27 topology: self.gguf_model.topology,
28 };
29
30 if self.gguf_model.with_logging {
31 initialize_logging();
32 }
33
34 let loader = GGUFLoaderBuilder::new(
35 self.gguf_model.chat_template,
36 self.gguf_model.tok_model_id,
37 self.gguf_model.model_id,
38 self.gguf_model.files,
39 config,
40 self.gguf_model.no_kv_cache,
41 self.gguf_model.jinja_explicit,
42 )
43 .with_lora(self.lora_model_id, self.ordering)
44 .build();
45
46 let pipeline = loader.load_model_from_hf(
48 self.gguf_model.hf_revision,
49 self.gguf_model.token_source,
50 &ModelDType::Auto,
51 &best_device(self.gguf_model.force_cpu)?,
52 !self.gguf_model.with_logging,
53 self.gguf_model
54 .device_mapping
55 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
56 None,
57 self.gguf_model.paged_attn_cfg,
58 )?;
59
60 let scheduler_method = match self.gguf_model.paged_attn_cfg {
61 Some(_) => {
62 let config = pipeline
63 .lock()
64 .await
65 .get_metadata()
66 .cache_config
67 .as_ref()
68 .unwrap()
69 .clone();
70
71 SchedulerConfig::PagedAttentionMeta {
72 max_num_seqs: self.gguf_model.max_num_seqs,
73 config,
74 }
75 }
76 None => SchedulerConfig::DefaultScheduler {
77 method: DefaultSchedulerMethod::Fixed(self.gguf_model.max_num_seqs.try_into()?),
78 },
79 };
80
81 let mut runner = MistralRsBuilder::new(
82 pipeline,
83 scheduler_method,
84 self.gguf_model.throughput_logging,
85 self.gguf_model.search_bert_model,
86 );
87 if let Some(cb) = self.gguf_model.search_callback.clone() {
88 runner = runner.with_search_callback(cb);
89 }
90 for (name, cb) in &self.gguf_model.tool_callbacks {
91 runner = runner.with_tool_callback(name.clone(), cb.clone());
92 }
93 runner = runner
94 .with_no_kv_cache(self.gguf_model.no_kv_cache)
95 .with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none());
96
97 if let Some(n) = self.gguf_model.prefix_cache_n {
98 runner = runner.with_prefix_cache_n(n)
99 }
100
101 Ok(Model::new(runner.build().await))
102 }
103}