1use std::{fs::File, path::PathBuf, str::FromStr};
2
3use mistralrs_quant::MULTI_LORA_DELIMITER;
4use serde::Deserialize;
5
6use crate::{
7 amoe::AnyMoeConfig, pipeline::IsqOrganization, AnyMoeLoader, AutoDeviceMapParams,
8 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader,
9 ModelDType, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig,
10 SpeculativeLoader, Topology, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
11 GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
12};
13
/// Serde default helper: yields `1` (used as the default for `gqa` fields).
fn default_one() -> usize {
    1_usize
}
17
18fn default_dtype() -> ModelDType {
19 ModelDType::Auto
20}
21
/// Serde default helper: yields an empty list of indices (used for `layers`).
fn default_empty_vec_usize() -> Vec<usize> {
    vec![]
}
25
26fn default_max_seq_len() -> usize {
27 AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN
28}
29
30fn default_max_batch_size() -> usize {
31 AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE
32}
33
34fn default_max_num_images() -> usize {
35 AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES
36}
37
38fn default_max_image_length() -> usize {
39 AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH
40}
41
/// The model described by a TOML selector, one variant per kind of loader.
///
/// The enum is deserialized without a tag: the variant is chosen purely from which keys
/// are present in the input.
// NOTE(review): with `#[serde(untagged)]` serde tries the variants in declaration order and
// keeps the first one that deserializes. No variant denies unknown fields, so an input that
// satisfies a later variant may already be accepted by an earlier one with fewer required
// keys (e.g. `GGML` requires the same keys as `GGUF`; an `XLora` input that also sets
// `model_id` appears to satisfy `Plain`). TODO confirm and, if so, consider an explicit tag.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum TomlModelSelected {
    /// A plain text model built with `NormalLoaderBuilder`.
    Plain {
        /// Model ID handed to the loader builder.
        model_id: String,

        /// Optional architecture passed to `build`.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `ModelDType::Auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology path, resolved with `Topology::from_option_path`.
        topology: Option<String>,

        /// ISQ organization; the type's `Default` is used when omitted.
        organization: Option<IsqOrganization>,

        /// Forwarded unchanged as `NormalSpecificConfig::write_uqff`.
        write_uqff: Option<PathBuf>,

        /// UQFF file path(s); several paths are separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        /// Forwarded unchanged as `NormalSpecificConfig::imatrix`.
        imatrix: Option<PathBuf>,

        /// Forwarded unchanged as `NormalSpecificConfig::calibration_file`.
        calibration_file: Option<PathBuf>,

        /// Maximum sequence length for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN`.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum batch size for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE`.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Forwarded unchanged as `NormalSpecificConfig::hf_cache_path`.
        hf_cache_path: Option<PathBuf>,
    },

    /// A text model with an X-LoRA adapter (`NormalLoaderBuilder::with_xlora`).
    XLora {
        /// Optional base model ID handed to the loader builder.
        model_id: Option<String>,

        /// X-LoRA model ID passed to `with_xlora`.
        xlora_model_id: String,

        /// Path to a JSON ordering file; it is opened and parsed when the loader is built.
        order: String,

        /// Forwarded unchanged to `with_xlora`.
        tgt_non_granular_index: Option<usize>,

        /// Optional architecture passed to `build`.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `ModelDType::Auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology path, resolved with `Topology::from_option_path`.
        topology: Option<String>,

        /// Forwarded unchanged as `NormalSpecificConfig::write_uqff`.
        write_uqff: Option<PathBuf>,

        /// UQFF file path(s); several paths are separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        /// Maximum sequence length for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN`.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum batch size for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE`.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Forwarded unchanged as `NormalSpecificConfig::hf_cache_path`.
        hf_cache_path: Option<PathBuf>,
    },

    /// A text model with LoRA adapters (`NormalLoaderBuilder::with_lora`).
    Lora {
        /// Optional base model ID handed to the loader builder.
        model_id: Option<String>,

        /// Adapter model ID(s); several IDs are separated by `MULTI_LORA_DELIMITER`.
        adapter_model_ids: String,

        /// Optional architecture passed to `build`.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `ModelDType::Auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology path, resolved with `Topology::from_option_path`.
        topology: Option<String>,

        /// Forwarded unchanged as `NormalSpecificConfig::write_uqff`.
        write_uqff: Option<PathBuf>,

        /// UQFF file path(s); several paths are separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        /// Maximum sequence length for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN`.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum batch size for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE`.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Forwarded unchanged as `NormalSpecificConfig::hf_cache_path`.
        hf_cache_path: Option<PathBuf>,
    },

    /// A GGUF quantized model built with `GGUFLoaderBuilder`.
    // The variant name is an acronym, hence the lint exception.
    #[allow(clippy::upper_case_acronyms)]
    GGUF {
        /// Tokenizer model ID handed to the loader builder.
        tok_model_id: String,

        /// Quantized model ID handed to the loader builder.
        quantized_model_id: String,

        /// Quantized filename(s); several names are separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        /// Model data type. Defaults to `ModelDType::Auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology path, resolved with `Topology::from_option_path`.
        topology: Option<String>,

        /// Maximum sequence length for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN`.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum batch size for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE`.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// A GGUF quantized model with an X-LoRA adapter (`GGUFLoaderBuilder::with_xlora`).
    XLoraGGUF {
        /// Optional tokenizer model ID handed to the loader builder.
        tok_model_id: Option<String>,

        /// Quantized model ID handed to the loader builder.
        quantized_model_id: String,

        /// Quantized filename(s); several names are separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        /// X-LoRA model ID passed to `with_xlora`.
        xlora_model_id: String,

        /// Path to a JSON ordering file; it is opened and parsed when the loader is built.
        order: String,

        /// Forwarded unchanged to `with_xlora`.
        tgt_non_granular_index: Option<usize>,

        /// Model data type. Defaults to `ModelDType::Auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology path, resolved with `Topology::from_option_path`.
        topology: Option<String>,

        /// Maximum sequence length for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN`.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum batch size for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE`.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// A GGUF quantized model with LoRA adapters (`GGUFLoaderBuilder::with_lora`).
    LoraGGUF {
        /// Optional tokenizer model ID handed to the loader builder.
        tok_model_id: Option<String>,

        /// Quantized model ID handed to the loader builder.
        quantized_model_id: String,

        /// Quantized filename(s); several names are separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        /// Adapters model ID passed to `with_lora`.
        adapters_model_id: String,

        /// Path to a JSON ordering file; it is opened and parsed when the loader is built.
        order: String,

        /// Model data type. Defaults to `ModelDType::Auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology path, resolved with `Topology::from_option_path`.
        topology: Option<String>,

        /// Maximum sequence length for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN`.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum batch size for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE`.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// A GGML quantized model built with `GGMLLoaderBuilder`.
    // The variant name is an acronym, hence the lint exception.
    #[allow(clippy::upper_case_acronyms)]
    GGML {
        /// Tokenizer model ID handed to the loader builder.
        tok_model_id: String,

        /// Quantized model ID handed to the loader builder.
        quantized_model_id: String,

        /// Quantized filename, handed to the loader builder unchanged.
        quantized_filename: String,

        /// Forwarded as `GGMLSpecificConfig::gqa`. Defaults to `1`.
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `ModelDType::Auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology path, resolved with `Topology::from_option_path`.
        topology: Option<String>,

        /// Maximum sequence length for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN`.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum batch size for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE`.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// A GGML quantized model with an X-LoRA adapter (`GGMLLoaderBuilder::with_xlora`).
    XLoraGGML {
        /// Optional tokenizer model ID handed to the loader builder.
        tok_model_id: Option<String>,

        /// Quantized model ID handed to the loader builder.
        quantized_model_id: String,

        /// Quantized filename, handed to the loader builder unchanged.
        quantized_filename: String,

        /// X-LoRA model ID passed to `with_xlora`.
        xlora_model_id: String,

        /// Path to a JSON ordering file; it is opened and parsed when the loader is built.
        order: String,

        /// Forwarded unchanged to `with_xlora`.
        tgt_non_granular_index: Option<usize>,

        /// Forwarded as `GGMLSpecificConfig::gqa`. Defaults to `1`.
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `ModelDType::Auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology path, resolved with `Topology::from_option_path`.
        topology: Option<String>,

        /// Maximum sequence length for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN`.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum batch size for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE`.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// A GGML quantized model with LoRA adapters (`GGMLLoaderBuilder::with_lora`).
    LoraGGML {
        /// Optional tokenizer model ID handed to the loader builder.
        tok_model_id: Option<String>,

        /// Quantized model ID handed to the loader builder.
        quantized_model_id: String,

        /// Quantized filename, handed to the loader builder unchanged.
        quantized_filename: String,

        /// Adapters model ID passed to `with_lora`.
        adapters_model_id: String,

        /// Path to a JSON ordering file; it is opened and parsed when the loader is built.
        order: String,

        /// Forwarded as `GGMLSpecificConfig::gqa`. Defaults to `1`.
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `ModelDType::Auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology path, resolved with `Topology::from_option_path`.
        topology: Option<String>,

        /// Maximum sequence length for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN`.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum batch size for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE`.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// A vision model built with `VisionLoaderBuilder`.
    VisionPlain {
        /// Model ID handed to the loader builder.
        model_id: String,

        /// Optional architecture passed to `build`.
        arch: Option<VisionLoaderType>,

        /// Model data type. Defaults to `ModelDType::Auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology path, resolved with `Topology::from_option_path`.
        topology: Option<String>,

        /// Forwarded unchanged as `VisionSpecificConfig::write_uqff`.
        write_uqff: Option<PathBuf>,

        /// UQFF file path(s); several paths are separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        /// Forwarded unchanged as `VisionSpecificConfig::max_edge`.
        max_edge: Option<u32>,

        /// Forwarded unchanged as `VisionSpecificConfig::calibration_file`.
        calibration_file: Option<PathBuf>,

        /// Forwarded unchanged as `VisionSpecificConfig::imatrix`.
        imatrix: Option<PathBuf>,

        /// Maximum sequence length for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN`.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum batch size for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE`.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Maximum number of images for automatic device mapping.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES`.
        #[serde(default = "default_max_num_images")]
        max_num_images: usize,

        /// Maximum image edge length for automatic device mapping; the same value is used
        /// for both dimensions of `max_image_shape`.
        /// Defaults to `AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH`.
        #[serde(default = "default_max_image_length")]
        max_image_length: usize,

        /// Forwarded unchanged as `VisionSpecificConfig::hf_cache_path`.
        hf_cache_path: Option<PathBuf>,
    },
}
443
/// Optional speculative-decoding section of a TOML selector.
///
/// When present, the target loader is wrapped in a `SpeculativeLoader` whose draft loader
/// is built from `draft_model`.
#[derive(Deserialize)]
pub struct SpeculativeTomlModelSelected {
    /// Forwarded as `SpeculativeConfig::gamma`.
    gamma: usize,

    /// Model description used to build the draft loader.
    draft_model: TomlModelSelected,
}
452
/// Optional AnyMoE section of a TOML selector.
///
/// When present, the loader built so far is wrapped in an `AnyMoeLoader`; every field
/// below is moved into the identically purposed field of that loader.
#[derive(Deserialize)]
pub struct AnyMoeTomlModelSelected {
    /// Forwarded as `AnyMoeLoader::config`.
    config: AnyMoeConfig,

    /// Forwarded as `AnyMoeLoader::path`.
    dataset_json: String,

    /// Forwarded as `AnyMoeLoader::prefix`.
    prefix: String,

    /// Forwarded as `AnyMoeLoader::mlp`.
    mlp: String,

    /// Forwarded as `AnyMoeLoader::model_ids`.
    model_ids: Vec<String>,

    /// Forwarded as `AnyMoeLoader::layers`. Defaults to an empty list.
    #[serde(default = "default_empty_vec_usize")]
    layers: Vec<usize>,
}
474
/// Top-level structure deserialized from a TOML model selector.
#[derive(Deserialize)]
pub struct TomlSelector {
    /// Optional tokenizer JSON location, passed to the loader builders that accept one.
    tokenizer_json: Option<String>,

    /// The (target) model to load.
    model: TomlModelSelected,

    /// Optional speculative-decoding configuration wrapping the target model.
    speculative: Option<SpeculativeTomlModelSelected>,

    /// Optional AnyMoE configuration wrapping the resulting loader.
    anymoe: Option<AnyMoeTomlModelSelected>,
}
489
/// Internal bundle of the parameters shared by every loader built from one selector.
///
/// It is `Clone` because the same parameters are used for both the target loader and,
/// when speculative decoding is configured, the draft loader.
#[derive(Clone)]
struct TomlLoaderInnerParams {
    /// Taken from `TomlLoaderArgs::chat_template`.
    chat_template: Option<String>,
    /// Taken from `TomlLoaderArgs::no_kv_cache`.
    no_kv_cache: bool,
    /// Taken from `TomlSelector::tokenizer_json`.
    tokenizer_json: Option<String>,
    /// Taken from `TomlLoaderArgs::jinja_explicit`.
    jinja_explicit: Option<String>,
}
497
/// Caller-supplied arguments that accompany a [`TomlSelector`] when building a loader.
pub struct TomlLoaderArgs {
    /// Optional chat template handed to the loader builders.
    pub chat_template: Option<String>,
    /// Handed to the loader builders that accept a `no_kv_cache` flag.
    pub no_kv_cache: bool,
    /// Optional explicit Jinja setting handed to the loader builders.
    pub jinja_explicit: Option<String>,
}
503
504pub fn get_toml_selected_model_dtype(model: &TomlSelector) -> ModelDType {
505 match model.model {
506 TomlModelSelected::Plain { dtype, .. }
507 | TomlModelSelected::Lora { dtype, .. }
508 | TomlModelSelected::XLora { dtype, .. }
509 | TomlModelSelected::VisionPlain { dtype, .. }
510 | TomlModelSelected::GGUF { dtype, .. }
511 | TomlModelSelected::GGML { dtype, .. }
512 | TomlModelSelected::XLoraGGUF { dtype, .. }
513 | TomlModelSelected::XLoraGGML { dtype, .. }
514 | TomlModelSelected::LoraGGUF { dtype, .. }
515 | TomlModelSelected::LoraGGML { dtype, .. } => dtype,
516 }
517}
518
519pub fn get_toml_selected_model_device_map_params(
520 model: &TomlSelector,
521) -> anyhow::Result<AutoDeviceMapParams> {
522 match model.model {
523 TomlModelSelected::Plain {
524 max_seq_len,
525 max_batch_size,
526 ..
527 }
528 | TomlModelSelected::Lora {
529 max_seq_len,
530 max_batch_size,
531 ..
532 }
533 | TomlModelSelected::XLora {
534 max_seq_len,
535 max_batch_size,
536 ..
537 }
538 | TomlModelSelected::GGML {
539 max_seq_len,
540 max_batch_size,
541 ..
542 }
543 | TomlModelSelected::GGUF {
544 max_seq_len,
545 max_batch_size,
546 ..
547 }
548 | TomlModelSelected::XLoraGGUF {
549 max_seq_len,
550 max_batch_size,
551 ..
552 }
553 | TomlModelSelected::XLoraGGML {
554 max_seq_len,
555 max_batch_size,
556 ..
557 }
558 | TomlModelSelected::LoraGGUF {
559 max_seq_len,
560 max_batch_size,
561 ..
562 }
563 | TomlModelSelected::LoraGGML {
564 max_seq_len,
565 max_batch_size,
566 ..
567 } => Ok(AutoDeviceMapParams::Text {
568 max_seq_len,
569 max_batch_size,
570 }),
571 TomlModelSelected::VisionPlain {
572 max_seq_len,
573 max_batch_size,
574 max_image_length,
575 max_num_images,
576 ..
577 } => Ok(AutoDeviceMapParams::Vision {
578 max_seq_len,
579 max_batch_size,
580 max_image_shape: (max_image_length, max_image_length),
581 max_num_images,
582 }),
583 }
584}
585
586fn loader_from_selected(
587 args: TomlLoaderInnerParams,
588 model: TomlModelSelected,
589) -> anyhow::Result<Box<dyn Loader>> {
590 let loader: Box<dyn Loader> = match model {
591 TomlModelSelected::Plain {
592 model_id,
593 arch,
594 dtype: _,
595 topology,
596 organization,
597 write_uqff,
598 from_uqff,
599 imatrix,
600 calibration_file,
601 max_seq_len: _,
602 max_batch_size: _,
603 hf_cache_path,
604 } => NormalLoaderBuilder::new(
605 NormalSpecificConfig {
606 topology: Topology::from_option_path(topology)?,
607 organization: organization.unwrap_or_default(),
608 write_uqff,
609 from_uqff: from_uqff.map(|x| {
610 x.split(UQFF_MULTI_FILE_DELIMITER)
611 .map(PathBuf::from_str)
612 .map(|x| x.unwrap())
613 .collect::<Vec<_>>()
614 }),
615 imatrix,
616 calibration_file,
617 hf_cache_path,
618 matformer_config_path: None,
619 matformer_slice_name: None,
620 },
621 args.chat_template,
622 args.tokenizer_json,
623 Some(model_id),
624 args.no_kv_cache,
625 args.jinja_explicit,
626 )
627 .build(arch)?,
628 TomlModelSelected::XLora {
629 model_id,
630 xlora_model_id,
631 order,
632 tgt_non_granular_index,
633 arch,
634 dtype: _,
635 topology,
636 write_uqff,
637 from_uqff,
638 max_seq_len: _,
639 max_batch_size: _,
640 hf_cache_path,
641 } => NormalLoaderBuilder::new(
642 NormalSpecificConfig {
643 topology: Topology::from_option_path(topology)?,
644 organization: Default::default(),
645 write_uqff,
646 from_uqff: from_uqff.map(|x| {
647 x.split(UQFF_MULTI_FILE_DELIMITER)
648 .map(PathBuf::from_str)
649 .map(|x| x.unwrap())
650 .collect::<Vec<_>>()
651 }),
652 imatrix: None,
653 calibration_file: None,
654 hf_cache_path,
655 matformer_config_path: None,
656 matformer_slice_name: None,
657 },
658 args.chat_template,
659 args.tokenizer_json,
660 model_id,
661 args.no_kv_cache,
662 args.jinja_explicit,
663 )
664 .with_xlora(
665 xlora_model_id,
666 serde_json::from_reader(
667 File::open(order.clone())
668 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
669 )?,
670 args.no_kv_cache,
671 tgt_non_granular_index,
672 )
673 .build(arch)?,
674 TomlModelSelected::Lora {
675 model_id,
676 adapter_model_ids,
677 arch,
678 dtype: _,
679 topology,
680 write_uqff,
681 from_uqff,
682 max_seq_len: _,
683 max_batch_size: _,
684 hf_cache_path,
685 } => NormalLoaderBuilder::new(
686 NormalSpecificConfig {
687 topology: Topology::from_option_path(topology)?,
688 organization: Default::default(),
689 write_uqff,
690 from_uqff: from_uqff.map(|x| {
691 x.split(UQFF_MULTI_FILE_DELIMITER)
692 .map(PathBuf::from_str)
693 .map(|x| x.unwrap())
694 .collect::<Vec<_>>()
695 }),
696 imatrix: None,
697 calibration_file: None,
698 hf_cache_path,
699 matformer_config_path: None,
700 matformer_slice_name: None,
701 },
702 args.chat_template,
703 args.tokenizer_json,
704 model_id,
705 args.no_kv_cache,
706 args.jinja_explicit,
707 )
708 .with_lora(
709 adapter_model_ids
710 .split(MULTI_LORA_DELIMITER)
711 .map(ToString::to_string)
712 .collect(),
713 )
714 .build(arch)?,
715 TomlModelSelected::GGUF {
716 tok_model_id,
717 quantized_model_id,
718 quantized_filename,
719 topology,
720 dtype: _,
721 max_seq_len: _,
722 max_batch_size: _,
723 } => GGUFLoaderBuilder::new(
724 args.chat_template,
725 Some(tok_model_id),
726 quantized_model_id,
727 quantized_filename
728 .split(GGUF_MULTI_FILE_DELIMITER)
729 .map(ToOwned::to_owned)
730 .collect::<Vec<_>>(),
731 GGUFSpecificConfig {
732 topology: Topology::from_option_path(topology)?,
733 },
734 args.no_kv_cache,
735 args.jinja_explicit,
736 )
737 .build(),
738 TomlModelSelected::XLoraGGUF {
739 tok_model_id,
740 quantized_model_id,
741 quantized_filename,
742 xlora_model_id,
743 order,
744 tgt_non_granular_index,
745 topology,
746 dtype: _,
747 max_seq_len: _,
748 max_batch_size: _,
749 } => GGUFLoaderBuilder::new(
750 args.chat_template,
751 tok_model_id,
752 quantized_model_id,
753 quantized_filename
754 .split(GGUF_MULTI_FILE_DELIMITER)
755 .map(ToOwned::to_owned)
756 .collect::<Vec<_>>(),
757 GGUFSpecificConfig {
758 topology: Topology::from_option_path(topology)?,
759 },
760 args.no_kv_cache,
761 args.jinja_explicit,
762 )
763 .with_xlora(
764 xlora_model_id,
765 serde_json::from_reader(
766 File::open(order.clone())
767 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
768 )?,
769 args.no_kv_cache,
770 tgt_non_granular_index,
771 )
772 .build(),
773 TomlModelSelected::LoraGGUF {
774 tok_model_id,
775 quantized_model_id,
776 quantized_filename,
777 adapters_model_id,
778 order,
779 topology,
780 ..
781 } => GGUFLoaderBuilder::new(
782 args.chat_template,
783 tok_model_id,
784 quantized_model_id,
785 quantized_filename
786 .split(GGUF_MULTI_FILE_DELIMITER)
787 .map(ToOwned::to_owned)
788 .collect::<Vec<_>>(),
789 GGUFSpecificConfig {
790 topology: Topology::from_option_path(topology)?,
791 },
792 args.no_kv_cache,
793 args.jinja_explicit,
794 )
795 .with_lora(
796 adapters_model_id,
797 serde_json::from_reader(
798 File::open(order.clone())
799 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
800 )?,
801 )
802 .build(),
803 TomlModelSelected::GGML {
804 tok_model_id,
805 quantized_model_id,
806 quantized_filename,
807 gqa,
808 topology,
809 dtype: _,
810 max_seq_len: _,
811 max_batch_size: _,
812 } => GGMLLoaderBuilder::new(
813 GGMLSpecificConfig {
814 gqa,
815 topology: Topology::from_option_path(topology)?,
816 },
817 args.chat_template,
818 args.tokenizer_json,
819 Some(tok_model_id),
820 quantized_model_id,
821 quantized_filename,
822 args.no_kv_cache,
823 args.jinja_explicit,
824 )
825 .build(),
826 TomlModelSelected::XLoraGGML {
827 tok_model_id,
828 quantized_model_id,
829 quantized_filename,
830 xlora_model_id,
831 order,
832 tgt_non_granular_index,
833 gqa,
834 topology,
835 dtype: _,
836 max_seq_len: _,
837 max_batch_size: _,
838 } => GGMLLoaderBuilder::new(
839 GGMLSpecificConfig {
840 gqa,
841 topology: Topology::from_option_path(topology)?,
842 },
843 args.chat_template,
844 args.tokenizer_json,
845 tok_model_id,
846 quantized_model_id,
847 quantized_filename,
848 args.no_kv_cache,
849 args.jinja_explicit,
850 )
851 .with_xlora(
852 xlora_model_id,
853 serde_json::from_reader(
854 File::open(order.clone())
855 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
856 )?,
857 args.no_kv_cache,
858 tgt_non_granular_index,
859 )
860 .build(),
861 TomlModelSelected::LoraGGML {
862 tok_model_id,
863 quantized_model_id,
864 quantized_filename,
865 adapters_model_id,
866 order,
867 gqa,
868 topology,
869 dtype: _,
870 max_seq_len: _,
871 max_batch_size: _,
872 } => GGMLLoaderBuilder::new(
873 GGMLSpecificConfig {
874 gqa,
875 topology: Topology::from_option_path(topology)?,
876 },
877 args.chat_template,
878 args.tokenizer_json,
879 tok_model_id,
880 quantized_model_id,
881 quantized_filename,
882 args.no_kv_cache,
883 args.jinja_explicit,
884 )
885 .with_lora(
886 adapters_model_id,
887 serde_json::from_reader(
888 File::open(order.clone())
889 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
890 )?,
891 )
892 .build(),
893 TomlModelSelected::VisionPlain {
894 model_id,
895 arch,
896 dtype: _,
897 topology,
898 write_uqff,
899 from_uqff,
900 max_edge,
901 calibration_file,
902 max_seq_len: _,
903 max_batch_size: _,
904 max_num_images: _,
905 max_image_length: _,
906 imatrix,
907 hf_cache_path,
908 } => VisionLoaderBuilder::new(
909 VisionSpecificConfig {
910 topology: Topology::from_option_path(topology)?,
911 write_uqff,
912 from_uqff: from_uqff.map(|x| {
913 x.split(UQFF_MULTI_FILE_DELIMITER)
914 .map(PathBuf::from_str)
915 .map(|x| x.unwrap())
916 .collect::<Vec<_>>()
917 }),
918 max_edge,
919 calibration_file,
920 imatrix,
921 hf_cache_path,
922 matformer_config_path: None,
923 matformer_slice_name: None,
924 },
925 args.chat_template,
926 args.tokenizer_json,
927 Some(model_id),
928 args.jinja_explicit,
929 )
930 .build(arch),
931 };
932 Ok(loader)
933}
934
935impl TryInto<Box<dyn Loader>> for (TomlSelector, TomlLoaderArgs) {
936 type Error = anyhow::Error;
937 fn try_into(self) -> Result<Box<dyn Loader>, Self::Error> {
938 let (selector, args) = self;
939 let args = TomlLoaderInnerParams {
940 chat_template: args.chat_template,
941 no_kv_cache: args.no_kv_cache,
942 tokenizer_json: selector.tokenizer_json,
943 jinja_explicit: args.jinja_explicit,
944 };
945 let loader = loader_from_selected(args.clone(), selector.model)?;
946 let loader = if let Some(speculative) = selector.speculative {
947 let draft_loader = loader_from_selected(args, speculative.draft_model)?;
948 Box::new(SpeculativeLoader {
949 target: loader,
950 draft: draft_loader,
951 config: SpeculativeConfig {
952 gamma: speculative.gamma,
953 },
954 })
955 } else {
956 loader
957 };
958 let loader = if let Some(AnyMoeTomlModelSelected {
959 config,
960 dataset_json,
961 prefix,
962 mlp,
963 model_ids,
964 layers,
965 }) = selector.anymoe
966 {
967 Box::new(AnyMoeLoader {
968 target: loader,
969 config,
970 path: dataset_json,
971 prefix,
972 mlp,
973 model_ids,
974 layers,
975 })
976 } else {
977 loader
978 };
979 Ok(loader)
980 }
981}