use candle_core::Device;
use mistralrs_core::*;
use std::num::NonZeroUsize;

use crate::{best_device, Model};

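/// Builder for a GGUF-quantized text model: configures loading, device
/// placement, and runtime behavior, and produces a [`Model`] via
/// [`GgufModelBuilder::build`].
///
/// A minimal usage sketch; the model ID and filename are illustrative
/// placeholders, and the import assumes this type is re-exported at the
/// crate root as `mistralrs::GgufModelBuilder`:
///
/// ```no_run
/// use mistralrs::GgufModelBuilder;
///
/// # async fn run() -> anyhow::Result<()> {
/// let model = GgufModelBuilder::new(
///     "some-org/some-model-GGUF",
///     vec!["some-model-q4_k_m.gguf"],
/// )
/// .with_logging()
/// .build()
/// .await?;
/// # Ok(())
/// # }
/// ```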
pub struct GgufModelBuilder {
    pub(crate) model_id: String,
    pub(crate) files: Vec<String>,
    pub(crate) tok_model_id: Option<String>,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
    pub(crate) device: Option<Device>,

    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    pub(crate) force_cpu: bool,
    pub(crate) topology: Option<Topology>,
    pub(crate) throughput_logging: bool,

    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    pub(crate) prefix_cache_n: Option<usize>,
}

impl GgufModelBuilder {
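    /// Create a builder from a Hugging Face model ID (or local path) and the
    /// GGUF filenames to load from it. Defaults set here: the token source is
    /// the Hugging Face token cache, at most 32 sequences run concurrently,
    /// and the prefix cache holds 16 sequences.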
    pub fn new(model_id: impl ToString, files: Vec<impl ToString>) -> Self {
        Self {
            model_id: model_id.to_string(),
            files: files.into_iter().map(|f| f.to_string()).collect::<Vec<_>>(),
            prompt_chunksize: None,
            chat_template: None,
            tokenizer_json: None,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            prefix_cache_n: Some(16),
            with_logging: false,
            topology: None,
            tok_model_id: None,
            device_mapping: None,
            jinja_explicit: None,
            throughput_logging: false,
            search_bert_model: None,
            device: None,
        }
    }

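    /// Enable search functionality, backed by the given BERT embedding model.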
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

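    /// Enable throughput logging on the runner.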
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

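    /// Provide an explicit Jinja chat template file, taking precedence over
    /// other chat template sources.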
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

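    /// Source the tokenizer and chat template from this model ID rather than
    /// from the GGUF model itself.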
    pub fn with_tok_model_id(mut self, tok_model_id: impl ToString) -> Self {
        self.tok_model_id = Some(tok_model_id.to_string());
        self
    }

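    /// Set the prompt chunk size: how many prompt tokens are processed at a time.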
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

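    /// Set the model topology to use during loading.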
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

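    /// Set the chat template: either a literal template string or a path to a
    /// template file.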
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

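    /// Use an explicit `tokenizer.json` file instead of the one resolved from
    /// the model ID.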
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

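    /// Force the model to run on the CPU.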
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

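    /// Set where the Hugging Face access token is read from.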
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

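    /// Set the Hugging Face revision (branch, tag, or commit) to download from.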
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

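    /// Enable PagedAttention with the supplied configuration. If PagedAttention
    /// is not supported on this platform, the closure is never invoked and the
    /// feature stays disabled.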
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

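    /// Set the maximum number of sequences that can be scheduled at once.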
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

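    /// Disable the KV cache, trading performance for lower memory usage.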
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

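    /// Set how many sequences the prefix cache may hold; `None` disables
    /// prefix caching.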
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

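    /// Enable logging during loading and inference.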
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

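    /// Provide an explicit device mapping, overriding the automatic mapping
    /// that `build` uses by default.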
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

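    /// Set the main device to load the model onto, instead of letting `build`
    /// pick the best available one.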
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

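    /// Load the model and assemble a runnable [`Model`]. A PagedAttention
    /// scheduler is used when a PagedAttention configuration was set;
    /// otherwise a fixed-capacity default scheduler is used.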
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = GGUFSpecificConfig {
            prompt_chunksize: self.prompt_chunksize,
            topology: self.topology,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = GGUFLoaderBuilder::new(
            self.chat_template,
            self.tok_model_id,
            self.model_id,
            self.files,
            config,
            self.no_kv_cache,
            self.jinja_explicit,
        )
        .build();

        // Load onto the explicitly requested device, otherwise the best
        // available one (respecting `force_cpu`); propagate rather than panic
        // on device-selection failure.
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &ModelDType::Auto,
            &self.device.unwrap_or(best_device(self.force_cpu)?),
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            None,
            self.paged_attn_cfg,
        )?;

        // When PagedAttention was configured, loading populates the pipeline's
        // cache configuration, which the scheduler needs.
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .unwrap()
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        )
        .with_no_kv_cache(self.no_kv_cache)
        .with_no_prefix_cache(self.prefix_cache_n.is_none());

        if let Some(n) = self.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n);
        }

        Ok(Model::new(runner.build()))
    }
}