1use std::{
2 fs::{self, File},
3 num::NonZeroUsize,
4 path::PathBuf,
5 str::FromStr,
6};
7
8use mistralrs_quant::MULTI_LORA_DELIMITER;
9
10use crate::{
11 get_toml_selected_model_dtype,
12 pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig},
13 toml_selector::get_toml_selected_model_device_map_params,
14 AutoDeviceMapParams, DiffusionLoaderBuilder, GGUFSpecificConfig, Loader, ModelDType,
15 ModelSelected, NormalLoaderBuilder, SpeechLoader, TomlLoaderArgs, TomlSelector, Topology,
16 VisionLoaderBuilder, VisionSpecificConfig, GGUF_MULTI_FILE_DELIMITER,
17 UQFF_MULTI_FILE_DELIMITER,
18};
19
/// Collects model-agnostic loading options before a [`Loader`] is
/// constructed via [`LoaderBuilder::build`].
pub struct LoaderBuilder {
    // Which model (format/adapter combination) to load.
    model: ModelSelected,
    // When true, the KV cache is disabled for the built loader.
    no_kv_cache: bool,
    // Optional chat template override (literal template or path, resolved downstream).
    chat_template: Option<String>,
    // Optional explicit Jinja template, forwarded to the underlying loader builders.
    jinja_explicit: Option<String>,
    // Optional prompt chunk size used for prefix-chunked prompt processing.
    prompt_chunksize: Option<NonZeroUsize>,
}
28
29impl LoaderBuilder {
30 pub fn new(model: ModelSelected) -> Self {
31 Self {
32 model,
33 no_kv_cache: false,
34 chat_template: None,
35 prompt_chunksize: None,
36 jinja_explicit: None,
37 }
38 }
39
40 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
41 self.no_kv_cache = no_kv_cache;
42 self
43 }
44 pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
45 self.chat_template = chat_template;
46 self
47 }
48 pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
49 self.jinja_explicit = jinja_explicit;
50 self
51 }
52 pub fn with_prompt_chunksize(mut self, prompt_chunksize: Option<NonZeroUsize>) -> Self {
53 self.prompt_chunksize = prompt_chunksize;
54 self
55 }
56
57 pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
58 loader_from_model_selected(self)
59 }
60}
61
62pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
63 match model {
64 ModelSelected::Plain { .. }
65 | ModelSelected::Lora { .. }
66 | ModelSelected::GGUF { .. }
67 | ModelSelected::LoraGGUF { .. }
68 | ModelSelected::GGML { .. }
69 | ModelSelected::LoraGGML { .. }
70 | ModelSelected::Toml { .. }
71 | ModelSelected::VisionPlain { .. }
72 | ModelSelected::DiffusionPlain { .. }
73 | ModelSelected::Speech { .. } => None,
74 ModelSelected::XLora {
75 tgt_non_granular_index,
76 ..
77 }
78 | ModelSelected::XLoraGGUF {
79 tgt_non_granular_index,
80 ..
81 }
82 | ModelSelected::XLoraGGML {
83 tgt_non_granular_index,
84 ..
85 } => *tgt_non_granular_index,
86 }
87}
88
89pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
90 match model {
91 ModelSelected::Plain { dtype, .. }
92 | ModelSelected::Lora { dtype, .. }
93 | ModelSelected::XLora { dtype, .. }
94 | ModelSelected::VisionPlain { dtype, .. }
95 | ModelSelected::DiffusionPlain { dtype, .. }
96 | ModelSelected::GGML { dtype, .. }
97 | ModelSelected::GGUF { dtype, .. }
98 | ModelSelected::XLoraGGUF { dtype, .. }
99 | ModelSelected::XLoraGGML { dtype, .. }
100 | ModelSelected::LoraGGUF { dtype, .. }
101 | ModelSelected::LoraGGML { dtype, .. }
102 | ModelSelected::Speech { dtype, .. } => Ok(*dtype),
103 ModelSelected::Toml { file } => {
104 let selector: TomlSelector = toml::from_str(
105 &fs::read_to_string(file.clone())
106 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
107 )?;
108 Ok(get_toml_selected_model_dtype(&selector))
109 }
110 }
111}
112
113pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
114 match model {
115 ModelSelected::Plain {
116 max_seq_len,
117 max_batch_size,
118 ..
119 }
120 | ModelSelected::Lora {
121 max_seq_len,
122 max_batch_size,
123 ..
124 }
125 | ModelSelected::XLora {
126 max_seq_len,
127 max_batch_size,
128 ..
129 }
130 | ModelSelected::GGML {
131 max_seq_len,
132 max_batch_size,
133 ..
134 }
135 | ModelSelected::GGUF {
136 max_seq_len,
137 max_batch_size,
138 ..
139 }
140 | ModelSelected::XLoraGGUF {
141 max_seq_len,
142 max_batch_size,
143 ..
144 }
145 | ModelSelected::XLoraGGML {
146 max_seq_len,
147 max_batch_size,
148 ..
149 }
150 | ModelSelected::LoraGGUF {
151 max_seq_len,
152 max_batch_size,
153 ..
154 }
155 | ModelSelected::LoraGGML {
156 max_seq_len,
157 max_batch_size,
158 ..
159 } => Ok(AutoDeviceMapParams::Text {
160 max_seq_len: *max_seq_len,
161 max_batch_size: *max_batch_size,
162 }),
163 ModelSelected::VisionPlain {
164 max_seq_len,
165 max_batch_size,
166 max_image_length,
167 max_num_images,
168 ..
169 } => Ok(AutoDeviceMapParams::Vision {
170 max_seq_len: *max_seq_len,
171 max_batch_size: *max_batch_size,
172 max_image_shape: (*max_image_length, *max_image_length),
173 max_num_images: *max_num_images,
174 }),
175 ModelSelected::DiffusionPlain { .. } | ModelSelected::Speech { .. } => {
176 Ok(AutoDeviceMapParams::default_text())
177 }
178 ModelSelected::Toml { file } => {
179 let selector: TomlSelector = toml::from_str(
180 &fs::read_to_string(file.clone())
181 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
182 )?;
183 get_toml_selected_model_device_map_params(&selector)
184 }
185 }
186}
187
188fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
189 let loader: Box<dyn Loader> = match args.model {
190 ModelSelected::Toml { file } => {
191 let selector: TomlSelector = toml::from_str(
192 &fs::read_to_string(file.clone())
193 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
194 )?;
195 let args = TomlLoaderArgs {
196 chat_template: args.chat_template,
197 no_kv_cache: args.no_kv_cache,
198 prompt_chunksize: args.prompt_chunksize,
199 jinja_explicit: args.jinja_explicit,
200 };
201 (selector, args).try_into()?
202 }
203 ModelSelected::Plain {
204 model_id,
205 tokenizer_json,
206 arch,
207 dtype: _,
208 topology,
209 organization,
210 write_uqff,
211 from_uqff,
212 imatrix,
213 calibration_file,
214 max_seq_len: _,
215 max_batch_size: _,
216 hf_cache_path,
217 } => NormalLoaderBuilder::new(
218 NormalSpecificConfig {
219 prompt_chunksize: args.prompt_chunksize,
220 topology: Topology::from_option_path(topology)?,
221 organization: organization.unwrap_or_default(),
222 write_uqff,
223 from_uqff: from_uqff.map(|x| {
224 x.split(UQFF_MULTI_FILE_DELIMITER)
225 .map(PathBuf::from_str)
226 .map(|x| x.unwrap())
227 .collect::<Vec<_>>()
228 }),
229 imatrix,
230 calibration_file,
231 hf_cache_path,
232 },
233 args.chat_template,
234 tokenizer_json,
235 Some(model_id),
236 args.no_kv_cache,
237 args.jinja_explicit,
238 )
239 .build(arch)?,
240 ModelSelected::XLora {
241 model_id,
242 xlora_model_id,
243 order,
244 tokenizer_json,
245 tgt_non_granular_index,
246 arch,
247 dtype: _,
248 topology,
249 write_uqff,
250 from_uqff,
251 max_seq_len: _,
252 max_batch_size: _,
253 hf_cache_path,
254 } => NormalLoaderBuilder::new(
255 NormalSpecificConfig {
256 prompt_chunksize: args.prompt_chunksize,
257 topology: Topology::from_option_path(topology)?,
258 organization: Default::default(),
259 write_uqff,
260 from_uqff: from_uqff.map(|x| {
261 x.split(UQFF_MULTI_FILE_DELIMITER)
262 .map(PathBuf::from_str)
263 .map(|x| x.unwrap())
264 .collect::<Vec<_>>()
265 }),
266 imatrix: None,
267 calibration_file: None,
268 hf_cache_path,
269 },
270 args.chat_template,
271 tokenizer_json,
272 model_id,
273 args.no_kv_cache,
274 args.jinja_explicit,
275 )
276 .with_xlora(
277 xlora_model_id,
278 serde_json::from_reader(
279 File::open(order.clone())
280 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
281 )?,
282 args.no_kv_cache,
283 tgt_non_granular_index,
284 )
285 .build(arch)?,
286 ModelSelected::Lora {
287 model_id,
288 tokenizer_json,
289 adapter_model_id,
290 arch,
291 dtype: _,
292 topology,
293 write_uqff,
294 from_uqff,
295 max_seq_len: _,
296 max_batch_size: _,
297 hf_cache_path,
298 } => NormalLoaderBuilder::new(
299 NormalSpecificConfig {
300 prompt_chunksize: args.prompt_chunksize,
301 topology: Topology::from_option_path(topology)?,
302 organization: Default::default(),
303 write_uqff,
304 from_uqff: from_uqff.map(|x| {
305 x.split(UQFF_MULTI_FILE_DELIMITER)
306 .map(PathBuf::from_str)
307 .map(|x| x.unwrap())
308 .collect::<Vec<_>>()
309 }),
310 imatrix: None,
311 calibration_file: None,
312 hf_cache_path,
313 },
314 args.chat_template,
315 tokenizer_json,
316 model_id,
317 args.no_kv_cache,
318 args.jinja_explicit,
319 )
320 .with_lora(
321 adapter_model_id
322 .split(MULTI_LORA_DELIMITER)
323 .map(ToString::to_string)
324 .collect(),
325 )
326 .build(arch)?,
327 ModelSelected::GGUF {
328 tok_model_id,
329 quantized_model_id,
330 quantized_filename,
331 topology,
332 ..
333 } => GGUFLoaderBuilder::new(
334 args.chat_template,
335 tok_model_id,
336 quantized_model_id,
337 quantized_filename
338 .split(GGUF_MULTI_FILE_DELIMITER)
339 .map(ToOwned::to_owned)
340 .collect::<Vec<_>>(),
341 GGUFSpecificConfig {
342 prompt_chunksize: args.prompt_chunksize,
343 topology: Topology::from_option_path(topology)?,
344 },
345 args.no_kv_cache,
346 args.jinja_explicit,
347 )
348 .build(),
349 ModelSelected::XLoraGGUF {
350 tok_model_id,
351 quantized_model_id,
352 quantized_filename,
353 xlora_model_id,
354 order,
355 tgt_non_granular_index,
356 topology,
357 ..
358 } => GGUFLoaderBuilder::new(
359 args.chat_template,
360 tok_model_id,
361 quantized_model_id,
362 quantized_filename
363 .split(GGUF_MULTI_FILE_DELIMITER)
364 .map(ToOwned::to_owned)
365 .collect::<Vec<_>>(),
366 GGUFSpecificConfig {
367 prompt_chunksize: args.prompt_chunksize,
368 topology: Topology::from_option_path(topology)?,
369 },
370 args.no_kv_cache,
371 args.jinja_explicit,
372 )
373 .with_xlora(
374 xlora_model_id,
375 serde_json::from_reader(
376 File::open(order.clone())
377 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
378 )?,
379 args.no_kv_cache,
380 tgt_non_granular_index,
381 )
382 .build(),
383 ModelSelected::LoraGGUF {
384 tok_model_id,
385 quantized_model_id,
386 quantized_filename,
387 adapters_model_id,
388 order,
389 topology,
390 ..
391 } => GGUFLoaderBuilder::new(
392 args.chat_template,
393 tok_model_id,
394 quantized_model_id,
395 quantized_filename
396 .split(GGUF_MULTI_FILE_DELIMITER)
397 .map(ToOwned::to_owned)
398 .collect::<Vec<_>>(),
399 GGUFSpecificConfig {
400 prompt_chunksize: args.prompt_chunksize,
401 topology: Topology::from_option_path(topology)?,
402 },
403 args.no_kv_cache,
404 args.jinja_explicit,
405 )
406 .with_lora(
407 adapters_model_id,
408 serde_json::from_reader(
409 File::open(order.clone())
410 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
411 )?,
412 )
413 .build(),
414 ModelSelected::GGML {
415 tok_model_id,
416 tokenizer_json,
417 quantized_model_id,
418 quantized_filename,
419 gqa,
420 topology,
421 ..
422 } => GGMLLoaderBuilder::new(
423 GGMLSpecificConfig {
424 gqa,
425 prompt_chunksize: args.prompt_chunksize,
426 topology: Topology::from_option_path(topology)?,
427 },
428 args.chat_template,
429 tokenizer_json,
430 Some(tok_model_id),
431 quantized_model_id,
432 quantized_filename,
433 args.no_kv_cache,
434 args.jinja_explicit,
435 )
436 .build(),
437 ModelSelected::XLoraGGML {
438 tok_model_id,
439 tokenizer_json,
440 quantized_model_id,
441 quantized_filename,
442 xlora_model_id,
443 order,
444 tgt_non_granular_index,
445 gqa,
446 topology,
447 ..
448 } => GGMLLoaderBuilder::new(
449 GGMLSpecificConfig {
450 gqa,
451 prompt_chunksize: args.prompt_chunksize,
452 topology: Topology::from_option_path(topology)?,
453 },
454 args.chat_template,
455 tokenizer_json,
456 tok_model_id,
457 quantized_model_id,
458 quantized_filename,
459 args.no_kv_cache,
460 args.jinja_explicit,
461 )
462 .with_xlora(
463 xlora_model_id,
464 serde_json::from_reader(
465 File::open(order.clone())
466 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
467 )?,
468 args.no_kv_cache,
469 tgt_non_granular_index,
470 )
471 .build(),
472 ModelSelected::LoraGGML {
473 tok_model_id,
474 tokenizer_json,
475 quantized_model_id,
476 quantized_filename,
477 adapters_model_id,
478 order,
479 gqa,
480 topology,
481 ..
482 } => GGMLLoaderBuilder::new(
483 GGMLSpecificConfig {
484 gqa,
485 prompt_chunksize: args.prompt_chunksize,
486 topology: Topology::from_option_path(topology)?,
487 },
488 args.chat_template,
489 tokenizer_json,
490 tok_model_id,
491 quantized_model_id,
492 quantized_filename,
493 args.no_kv_cache,
494 args.jinja_explicit,
495 )
496 .with_lora(
497 adapters_model_id,
498 serde_json::from_reader(
499 File::open(order.clone())
500 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
501 )?,
502 )
503 .build(),
504 ModelSelected::VisionPlain {
505 model_id,
506 tokenizer_json,
507 arch,
508 dtype: _,
509 topology,
510 write_uqff,
511 from_uqff,
512 max_edge,
513 calibration_file,
514 max_seq_len: _,
515 max_batch_size: _,
516 max_num_images: _,
517 max_image_length: _,
518 hf_cache_path,
519 imatrix,
520 } => VisionLoaderBuilder::new(
521 VisionSpecificConfig {
522 prompt_chunksize: args.prompt_chunksize,
523 topology: Topology::from_option_path(topology)?,
524 write_uqff,
525 from_uqff: from_uqff.map(|x| {
526 x.split(UQFF_MULTI_FILE_DELIMITER)
527 .map(PathBuf::from_str)
528 .map(|x| x.unwrap())
529 .collect::<Vec<_>>()
530 }),
531 max_edge,
532 calibration_file,
533 imatrix,
534 hf_cache_path,
535 },
536 args.chat_template,
537 tokenizer_json,
538 Some(model_id),
539 args.jinja_explicit,
540 )
541 .build(arch),
542 ModelSelected::DiffusionPlain {
543 model_id,
544 arch,
545 dtype: _,
546 } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
547 ModelSelected::Speech {
548 model_id,
549 dac_model_id,
550 arch,
551 ..
552 } => Box::new(SpeechLoader {
553 model_id,
554 dac_model_id,
555 arch,
556 cfg: None,
557 }),
558 };
559 Ok(loader)
560}