use std::{
    fs::{self, File},
    path::PathBuf,
    str::FromStr,
};

use mistralrs_quant::MULTI_LORA_DELIMITER;

use crate::{
    get_toml_selected_model_dtype,
    pipeline::{
        AutoLoaderBuilder, DiffusionLoaderBuilder, GGMLLoaderBuilder, GGMLSpecificConfig,
        GGUFLoaderBuilder, GGUFSpecificConfig, NormalLoaderBuilder, NormalSpecificConfig,
        VisionLoaderBuilder, VisionSpecificConfig,
    },
    toml_selector::get_toml_selected_model_device_map_params,
    AutoDeviceMapParams, EmbeddingLoaderBuilder, EmbeddingSpecificConfig, Loader, ModelDType,
    ModelSelected, SpeechLoader, TomlLoaderArgs, TomlSelector, Topology, GGUF_MULTI_FILE_DELIMITER,
    UQFF_MULTI_FILE_DELIMITER,
};

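/// Builder for constructing a [`Loader`] from a [`ModelSelected`] choice.
///
/// Optional settings (chat template, explicit Jinja template, disabling the KV
/// cache) can be layered on before calling [`LoaderBuilder::build`].
///
/// Illustrative sketch (not a doctest); assumes a `ModelSelected` value named
/// `model` has already been constructed elsewhere:
///
/// ```ignore
/// let loader = LoaderBuilder::new(model)
///     .with_no_kv_cache(false)
///     .with_chat_template(None)
///     .build()?;
/// ```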
pub struct LoaderBuilder {
    model: ModelSelected,
    no_kv_cache: bool,
    chat_template: Option<String>,
    jinja_explicit: Option<String>,
}

impl LoaderBuilder {
    pub fn new(model: ModelSelected) -> Self {
        Self {
            model,
            no_kv_cache: false,
            chat_template: None,
            jinja_explicit: None,
        }
    }

    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
        self.chat_template = chat_template;
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
        self.jinja_explicit = jinja_explicit;
        self
    }

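    /// Consume the builder and construct the concrete [`Loader`] for the
    /// selected model.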
    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
        loader_from_model_selected(self)
    }
}

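/// Return the target non-granular index for X-LoRA selections (`XLora`,
/// `XLoraGGUF`, `XLoraGGML`); all other variants yield `None`.
///
/// Panics if called with the `MultiModel` variant, which is not valid here.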
pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
    match model {
        ModelSelected::Plain { .. }
        | ModelSelected::Run { .. }
        | ModelSelected::Lora { .. }
        | ModelSelected::GGUF { .. }
        | ModelSelected::LoraGGUF { .. }
        | ModelSelected::GGML { .. }
        | ModelSelected::LoraGGML { .. }
        | ModelSelected::Toml { .. }
        | ModelSelected::VisionPlain { .. }
        | ModelSelected::DiffusionPlain { .. }
        | ModelSelected::Speech { .. }
        | ModelSelected::Embedding { .. } => None,
        ModelSelected::XLora {
            tgt_non_granular_index,
            ..
        }
        | ModelSelected::XLoraGGUF {
            tgt_non_granular_index,
            ..
        }
        | ModelSelected::XLoraGGML {
            tgt_non_granular_index,
            ..
        } => *tgt_non_granular_index,
        ModelSelected::MultiModel { .. } => {
            panic!("MultiModel variant should not be used in model loading functions")
        }
    }
}

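/// Resolve the [`ModelDType`] requested by the given model selection, reading
/// the TOML selector file when the `Toml` variant is used.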
pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
    match model {
        ModelSelected::Plain { dtype, .. }
        | ModelSelected::Lora { dtype, .. }
        | ModelSelected::XLora { dtype, .. }
        | ModelSelected::VisionPlain { dtype, .. }
        | ModelSelected::DiffusionPlain { dtype, .. }
        | ModelSelected::GGML { dtype, .. }
        | ModelSelected::GGUF { dtype, .. }
        | ModelSelected::XLoraGGUF { dtype, .. }
        | ModelSelected::XLoraGGML { dtype, .. }
        | ModelSelected::LoraGGUF { dtype, .. }
        | ModelSelected::LoraGGML { dtype, .. }
        | ModelSelected::Run { dtype, .. }
        | ModelSelected::Speech { dtype, .. }
        | ModelSelected::Embedding { dtype, .. } => Ok(*dtype),
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            Ok(get_toml_selected_model_dtype(&selector))
        }
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    }
}

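/// Derive the [`AutoDeviceMapParams`] implied by the model selection: vision
/// parameters when image limits apply, text parameters otherwise (diffusion,
/// speech, and embedding selections use the text defaults).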
pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
    match model {
        ModelSelected::Plain {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::Lora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::GGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::GGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::LoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::LoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        } => Ok(AutoDeviceMapParams::Text {
            max_seq_len: *max_seq_len,
            max_batch_size: *max_batch_size,
        }),
        ModelSelected::Run {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => {
            if max_num_images.is_some() || max_image_length.is_some() {
                let max_image_length =
                    max_image_length.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH);
                Ok(AutoDeviceMapParams::Vision {
                    max_seq_len: *max_seq_len,
                    max_batch_size: *max_batch_size,
                    max_image_shape: (max_image_length, max_image_length),
                    max_num_images: max_num_images
                        .unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES),
                })
            } else {
                Ok(AutoDeviceMapParams::Text {
                    max_seq_len: *max_seq_len,
                    max_batch_size: *max_batch_size,
                })
            }
        }
        ModelSelected::VisionPlain {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => Ok(AutoDeviceMapParams::Vision {
            max_seq_len: *max_seq_len,
            max_batch_size: *max_batch_size,
            max_image_shape: (*max_image_length, *max_image_length),
            max_num_images: *max_num_images,
        }),
        ModelSelected::DiffusionPlain { .. }
        | ModelSelected::Speech { .. }
        | ModelSelected::Embedding { .. } => Ok(AutoDeviceMapParams::default_text()),
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            get_toml_selected_model_device_map_params(&selector)
        }
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    }
}

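/// Dispatch on the [`ModelSelected`] variant and assemble the matching loader
/// (normal, vision, diffusion, speech, GGUF, GGML, or embedding).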
fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
    let loader: Box<dyn Loader> = match args.model {
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            let args = TomlLoaderArgs {
                chat_template: args.chat_template,
                no_kv_cache: args.no_kv_cache,
                jinja_explicit: args.jinja_explicit,
            };
            (selector, args).try_into()?
        }
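        // Plain text models are handled by the NormalLoaderBuilder.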
        ModelSelected::Plain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
            matformer_config_path,
            matformer_slice_name,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                organization: organization.unwrap_or_default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix,
                calibration_file,
                hf_cache_path,
                matformer_config_path,
                matformer_slice_name,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(arch)?,
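        // `Run` auto-detects the model kind: the AutoLoaderBuilder receives
        // text, vision, and embedding configs and picks the appropriate one.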
        ModelSelected::Run {
            model_id,
            tokenizer_json,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_edge,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
            matformer_config_path,
            matformer_slice_name,
        } => {
            let builder = AutoLoaderBuilder::new(
                NormalSpecificConfig {
                    topology: Topology::from_option_path(topology.clone())?,
                    organization: organization.unwrap_or_default(),
                    write_uqff: write_uqff.clone(),
                    from_uqff: from_uqff.clone().map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    imatrix: imatrix.clone(),
                    calibration_file: calibration_file.clone(),
                    hf_cache_path: hf_cache_path.clone(),
                    matformer_config_path: matformer_config_path.clone(),
                    matformer_slice_name: matformer_slice_name.clone(),
                },
                VisionSpecificConfig {
                    topology: Topology::from_option_path(topology.clone())?,
                    write_uqff: write_uqff.clone(),
                    from_uqff: from_uqff.clone().map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    max_edge,
                    calibration_file,
                    imatrix,
                    hf_cache_path: hf_cache_path.clone(),
                    matformer_config_path,
                    matformer_slice_name,
                },
                EmbeddingSpecificConfig {
                    topology: Topology::from_option_path(topology)?,
                    write_uqff,
                    from_uqff: from_uqff.map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    hf_cache_path: hf_cache_path.clone(),
                },
                args.chat_template,
                tokenizer_json,
                model_id,
                args.no_kv_cache,
                args.jinja_explicit,
            );
            let builder = if let Some(ref path) = hf_cache_path {
                builder.hf_cache_path(path.clone())
            } else {
                builder
            };
            builder.build()
        }
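        // Vision, diffusion, and speech models each have a dedicated loader.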
        ModelSelected::VisionPlain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_edge,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
            imatrix,
            matformer_config_path,
            matformer_slice_name,
        } => VisionLoaderBuilder::new(
            VisionSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                max_edge,
                calibration_file,
                imatrix,
                hf_cache_path,
                matformer_config_path,
                matformer_slice_name,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.jinja_explicit,
        )
        .build(arch),
        ModelSelected::DiffusionPlain {
            model_id,
            arch,
            dtype: _,
        } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
        ModelSelected::Speech {
            model_id,
            dac_model_id,
            arch,
            ..
        } => Box::new(SpeechLoader {
            model_id,
            dac_model_id,
            arch,
            cfg: None,
        }),
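        // Adapter variants: X-LoRA and LoRA layered on top of a plain text model.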
        ModelSelected::XLora {
            model_id,
            xlora_model_id,
            order,
            tokenizer_json,
            tgt_non_granular_index,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(arch)?,
        ModelSelected::Lora {
            model_id,
            tokenizer_json,
            adapter_model_id,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapter_model_id
                .split(MULTI_LORA_DELIMITER)
                .map(ToString::to_string)
                .collect(),
        )
        .build(arch)?,
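        // GGUF-quantized models, optionally with X-LoRA or LoRA adapters.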
        ModelSelected::GGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
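        // GGML-quantized models, optionally with X-LoRA or LoRA adapters.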
        ModelSelected::GGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
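        // Embedding models have their own loader; MultiModel is not valid here.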
        ModelSelected::Embedding {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            hf_cache_path,
        } => EmbeddingLoaderBuilder::new(
            EmbeddingSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                hf_cache_path,
            },
            tokenizer_json,
            Some(model_id),
        )
        .build(arch),
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    };
    Ok(loader)
}