use mistralrs_core::*;
use std::num::NonZeroUsize;

use crate::{best_device, Model};
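
/// Configure a GGUF model with the various parameters for loading, running, and other
/// inference behaviors.
///
/// A minimal usage sketch; the model ID and GGUF filename below are illustrative
/// placeholders, not recommendations:
///
/// ```no_run
/// use mistralrs::GgufModelBuilder;
///
/// # async fn run() -> anyhow::Result<()> {
/// let model = GgufModelBuilder::new(
///     "bartowski/Phi-3.5-mini-instruct-GGUF",
///     vec!["Phi-3.5-mini-instruct-Q4_K_M.gguf"],
/// )
/// .with_logging()
/// .build()
/// .await?;
/// # Ok(())
/// # }
/// ```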
pub struct GgufModelBuilder {
    // Loading the model
    pub(crate) model_id: String,
    pub(crate) files: Vec<String>,
    pub(crate) tok_model_id: Option<String>,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,

    // Running the model
    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    pub(crate) force_cpu: bool,
    pub(crate) topology: Option<Topology>,
    pub(crate) throughput_logging: bool,

    // Scheduling, caching, and logging
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    pub(crate) prefix_cache_n: Option<usize>,
}

impl GgufModelBuilder {
    /// Create a new builder for a GGUF model. `model_id` is a Hugging Face model ID (or a
    /// local path), and `files` are the GGUF filenames to load from it. By default, the
    /// tokenizer and chat template are sourced from the GGUF file itself; use
    /// [`Self::with_tok_model_id`] to source them from a separate model instead.
    pub fn new(model_id: impl ToString, files: Vec<impl ToString>) -> Self {
        Self {
            model_id: model_id.to_string(),
            files: files.into_iter().map(|f| f.to_string()).collect::<Vec<_>>(),
            prompt_chunksize: None,
            chat_template: None,
            tokenizer_json: None,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            prefix_cache_n: Some(16),
            with_logging: false,
            topology: None,
            tok_model_id: None,
            device_mapping: None,
            jinja_explicit: None,
            throughput_logging: false,
            search_bert_model: None,
        }
    }

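    /// Enable web searching, compatible with the OpenAI `web_search_options` request
    /// setting, using the given BERT model for embeddings.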
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

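    /// Enable throughput logging on the runner.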
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

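    /// Set an explicit Jinja chat template file (`.jinja`); if specified, it takes
    /// precedence over any other chat template source.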
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

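    /// Source the tokenizer and chat template from this model ID instead of from the
    /// GGUF file itself.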
    pub fn with_tok_model_id(mut self, tok_model_id: impl ToString) -> Self {
        self.tok_model_id = Some(tok_model_id.to_string());
        self
    }

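    /// Set the number of prompt tokens processed per forward pass, bounding peak memory
    /// use when ingesting long prompts.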
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

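    /// Set the model topology to apply during loading.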
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

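    /// Set a literal Jinja chat template, or a path (ending in `.json`) to one.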
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

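    /// Set the path to a local `tokenizer.json` file.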
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

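    /// Force usage of the CPU device; avoid combining this with PagedAttention.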
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

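    /// Set where the Hugging Face token is sourced from; defaults to the cached token.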
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

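    /// Set the revision (branch, tag, or commit) to use for a remote Hugging Face model.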
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

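    /// Enable PagedAttention, using the configuration produced by the given closure. If
    /// PagedAttention is not supported on this platform (see [`paged_attn_supported`]),
    /// the closure is never called and this is a no-op.
    ///
    /// A sketch of a typical call, assuming the defaults of `PagedAttentionMetaBuilder`
    /// fit the deployment:
    ///
    /// ```no_run
    /// # use mistralrs::{GgufModelBuilder, PagedAttentionMetaBuilder};
    /// # fn configure(builder: GgufModelBuilder) -> anyhow::Result<GgufModelBuilder> {
    /// let builder = builder.with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
    /// # Ok(builder)
    /// # }
    /// ```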
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

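    /// Set the maximum number of sequences that can be scheduled at once. Defaults to 32.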
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

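    /// Disable the KV cache, trading inference speed for lower memory usage.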
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

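    /// Set the number of sequences to hold in the prefix cache (defaults to 16); pass
    /// `None` to disable the prefix cacher.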
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

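    /// Enable logging; the logger is initialized when the model is built.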
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

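    /// Provide an explicit device mapping setting; if unset, an automatic device map for
    /// text models is used.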
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

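    /// Load the model as configured, set up the scheduler (PagedAttention-backed when
    /// configured and supported), and return a ready-to-use [`Model`].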
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = GGUFSpecificConfig {
            prompt_chunksize: self.prompt_chunksize,
            topology: self.topology,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = GGUFLoaderBuilder::new(
            self.chat_template,
            self.tok_model_id,
            self.model_id,
            self.files,
            config,
            self.no_kv_cache,
            self.jinja_explicit,
        )
        .build();

        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &ModelDType::Auto,
            &best_device(self.force_cpu)?,
            !self.with_logging, // silent unless logging was enabled
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            None,
            self.paged_attn_cfg,
        )?;

        // With PagedAttention, the scheduler needs the KV cache configuration computed
        // while loading; otherwise, use the fixed-size default scheduler.
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .unwrap() // populated during loading when PagedAttention is enabled
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        )
        .with_no_kv_cache(self.no_kv_cache)
        .with_no_prefix_cache(self.prefix_cache_n.is_none());

        if let Some(n) = self.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n);
        }

        Ok(Model::new(runner.build()))
    }
}