1use std::{fs::File, path::PathBuf, str::FromStr};
2
3use mistralrs_quant::MULTI_LORA_DELIMITER;
4use serde::Deserialize;
5
6use crate::{
7 amoe::AnyMoeConfig,
8 pipeline::{EmbeddingLoaderType, IsqOrganization},
9 AnyMoeLoader, AutoDeviceMapParams, EmbeddingLoaderBuilder, EmbeddingSpecificConfig,
10 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader,
11 ModelDType, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig,
12 SpeculativeLoader, Topology, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
13 GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
14};
15
/// Serde default that yields one (used for the GGML `gqa` fields).
fn default_one() -> usize {
    1_usize
}
19
20fn default_dtype() -> ModelDType {
21 ModelDType::Auto
22}
23
/// Serde default producing an empty list (used for the AnyMoE `layers` field).
fn default_empty_vec_usize() -> Vec<usize> {
    vec![]
}
27
28fn default_max_seq_len() -> usize {
29 AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN
30}
31
32fn default_max_batch_size() -> usize {
33 AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE
34}
35
36fn default_max_num_images() -> usize {
37 AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES
38}
39
40fn default_max_image_length() -> usize {
41 AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH
42}
43
/// Model description parsed from the `model` table of a TOML selector (and from
/// `speculative.draft_model`).
///
/// This enum is `#[serde(untagged)]`: serde tries the variants in declaration order and keeps
/// the first one that deserializes successfully, so the order of the variants below matters.
/// NOTE(review): `Plain` only requires `model_id` (all its other fields are optional or
/// defaulted) and unknown keys are not rejected, so a table meant for a later variant that also
/// has `model_id` (e.g. `VisionPlain`, `Embedding`) will parse as `Plain` unless some field, such
/// as `arch`, fails to deserialize for `Plain` — confirm this precedence is intended.
///
/// In every variant, `dtype` is only read by `get_toml_selected_model_dtype`, and
/// `max_seq_len` / `max_batch_size` (plus the vision image limits) are only read by
/// `get_toml_selected_model_device_map_params`; `loader_from_selected` ignores them.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum TomlModelSelected {
    /// Model built with `NormalLoaderBuilder`, without adapters.
    Plain {
        /// Model ID handed to the loader builder.
        model_id: String,

        /// Optional architecture, passed to `NormalLoaderBuilder::build`.
        arch: Option<NormalLoaderType>,

        /// Defaults to `ModelDType::Auto` when omitted.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Passed to `Topology::from_option_path`.
        topology: Option<String>,

        /// ISQ organization; `IsqOrganization::default()` is used when omitted.
        organization: Option<IsqOrganization>,

        /// Forwarded as `NormalSpecificConfig::write_uqff`.
        write_uqff: Option<PathBuf>,

        /// One or more UQFF file paths separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        /// Forwarded as `NormalSpecificConfig::imatrix`.
        imatrix: Option<PathBuf>,

        /// Forwarded as `NormalSpecificConfig::calibration_file`.
        calibration_file: Option<PathBuf>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Forwarded as `NormalSpecificConfig::hf_cache_path`.
        hf_cache_path: Option<PathBuf>,
    },

    /// `NormalLoaderBuilder` model with X-LoRA adapters attached via `with_xlora`.
    XLora {
        /// Base model ID; passed to the builder as-is (may be `None`).
        model_id: Option<String>,

        /// X-LoRA adapter model ID, passed to `with_xlora`.
        xlora_model_id: String,

        /// Path of a JSON ordering file, deserialized and passed to `with_xlora`.
        order: String,

        /// Passed to `with_xlora` unchanged.
        tgt_non_granular_index: Option<usize>,

        /// Optional architecture, passed to `NormalLoaderBuilder::build`.
        arch: Option<NormalLoaderType>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Passed to `Topology::from_option_path`.
        topology: Option<String>,

        write_uqff: Option<PathBuf>,

        /// One or more UQFF file paths separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        hf_cache_path: Option<PathBuf>,
    },

    /// `NormalLoaderBuilder` model with LoRA adapters attached via `with_lora`.
    Lora {
        /// Base model ID; passed to the builder as-is (may be `None`).
        model_id: Option<String>,

        /// One or more adapter model IDs separated by `MULTI_LORA_DELIMITER`.
        adapter_model_ids: String,

        /// Optional architecture, passed to `NormalLoaderBuilder::build`.
        arch: Option<NormalLoaderType>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Passed to `Topology::from_option_path`.
        topology: Option<String>,

        write_uqff: Option<PathBuf>,

        /// One or more UQFF file paths separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        hf_cache_path: Option<PathBuf>,
    },

    /// GGUF model built with `GGUFLoaderBuilder`.
    #[allow(clippy::upper_case_acronyms)]
    GGUF {
        /// Tokenizer model ID (required for this variant).
        tok_model_id: String,

        /// Model ID hosting the quantized file(s).
        quantized_model_id: String,

        /// One or more GGUF filenames separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Passed to `Topology::from_option_path`.
        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// GGUF model with X-LoRA adapters attached via `with_xlora`.
    XLoraGGUF {
        /// Optional tokenizer model ID.
        tok_model_id: Option<String>,

        quantized_model_id: String,

        /// One or more GGUF filenames separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        /// X-LoRA adapter model ID, passed to `with_xlora`.
        xlora_model_id: String,

        /// Path of a JSON ordering file, deserialized and passed to `with_xlora`.
        order: String,

        /// Passed to `with_xlora` unchanged.
        tgt_non_granular_index: Option<usize>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// GGUF model with LoRA adapters attached via `with_lora`.
    LoraGGUF {
        /// Optional tokenizer model ID.
        tok_model_id: Option<String>,

        quantized_model_id: String,

        /// One or more GGUF filenames separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        /// Adapter model ID, passed to `with_lora`.
        adapters_model_id: String,

        /// Path of a JSON ordering file, deserialized and passed to `with_lora`.
        order: String,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// GGML model built with `GGMLLoaderBuilder`.
    #[allow(clippy::upper_case_acronyms)]
    GGML {
        /// Tokenizer model ID (required for this variant).
        tok_model_id: String,

        quantized_model_id: String,

        /// A single GGML filename (not split on any delimiter).
        quantized_filename: String,

        /// Forwarded as `GGMLSpecificConfig::gqa`; defaults to 1.
        #[serde(default = "default_one")]
        gqa: usize,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// GGML model with X-LoRA adapters attached via `with_xlora`.
    XLoraGGML {
        /// Optional tokenizer model ID.
        tok_model_id: Option<String>,

        quantized_model_id: String,

        /// A single GGML filename.
        quantized_filename: String,

        xlora_model_id: String,

        /// Path of a JSON ordering file, deserialized and passed to `with_xlora`.
        order: String,

        tgt_non_granular_index: Option<usize>,

        /// Forwarded as `GGMLSpecificConfig::gqa`; defaults to 1.
        #[serde(default = "default_one")]
        gqa: usize,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// GGML model with LoRA adapters attached via `with_lora`.
    LoraGGML {
        /// Optional tokenizer model ID.
        tok_model_id: Option<String>,

        quantized_model_id: String,

        /// A single GGML filename.
        quantized_filename: String,

        adapters_model_id: String,

        /// Path of a JSON ordering file, deserialized and passed to `with_lora`.
        order: String,

        /// Forwarded as `GGMLSpecificConfig::gqa`; defaults to 1.
        #[serde(default = "default_one")]
        gqa: usize,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Vision model built with `VisionLoaderBuilder`.
    VisionPlain {
        model_id: String,

        /// Optional architecture, passed to `VisionLoaderBuilder::build`.
        arch: Option<VisionLoaderType>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        write_uqff: Option<PathBuf>,

        /// One or more UQFF file paths separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        /// Forwarded as `VisionSpecificConfig::max_edge`.
        max_edge: Option<u32>,

        calibration_file: Option<PathBuf>,

        imatrix: Option<PathBuf>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        #[serde(default = "default_max_num_images")]
        max_num_images: usize,

        /// Used for both dimensions of the device-mapping `max_image_shape`.
        #[serde(default = "default_max_image_length")]
        max_image_length: usize,

        hf_cache_path: Option<PathBuf>,
    },

    /// Embedding model built with `EmbeddingLoaderBuilder`.
    ///
    /// Has no sequence/batch limits; device mapping uses `AutoDeviceMapParams::default_text()`.
    Embedding {
        model_id: String,

        /// Tokenizer JSON for this model; the selector's top-level `tokenizer_json` is not
        /// used for this variant.
        #[serde(default)]
        tokenizer_json: Option<String>,

        #[serde(default)]
        arch: Option<EmbeddingLoaderType>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        #[serde(default)]
        topology: Option<String>,

        #[serde(default)]
        write_uqff: Option<PathBuf>,

        /// One or more UQFF file paths separated by `UQFF_MULTI_FILE_DELIMITER`.
        #[serde(default)]
        from_uqff: Option<String>,

        #[serde(default)]
        hf_cache_path: Option<PathBuf>,
    },
}
479
480#[derive(Deserialize)]
481pub struct SpeculativeTomlModelSelected {
482 gamma: usize,
484
485 draft_model: TomlModelSelected,
487}
488
/// `anymoe` table of a TOML selector: wraps the selected loader (after any speculative
/// wrapping) in an `AnyMoeLoader`.
#[derive(Deserialize)]
pub struct AnyMoeTomlModelSelected {
    /// Forwarded unchanged as `AnyMoeLoader::config`.
    config: AnyMoeConfig,

    /// Forwarded as `AnyMoeLoader::path`.
    dataset_json: String,

    /// Forwarded as `AnyMoeLoader::prefix`.
    prefix: String,

    /// Forwarded as `AnyMoeLoader::mlp`.
    mlp: String,

    /// Forwarded as `AnyMoeLoader::model_ids`.
    model_ids: Vec<String>,

    /// Forwarded as `AnyMoeLoader::layers`; an empty list when omitted from the TOML.
    #[serde(default = "default_empty_vec_usize")]
    layers: Vec<usize>,
}
510
/// Root of a TOML selector file.
///
/// Converted into a loader via `TryInto<Box<dyn Loader>>` for `(TomlSelector, TomlLoaderArgs)`.
#[derive(Deserialize)]
pub struct TomlSelector {
    /// Tokenizer JSON forwarded to the Normal, X-LoRA/LoRA, GGML and vision loader builders
    /// (for both target and draft models). The GGUF builders do not receive it, and the
    /// `Embedding` variant uses its own `tokenizer_json` field instead.
    tokenizer_json: Option<String>,

    /// Target model.
    model: TomlModelSelected,

    /// Optional speculative-decoding configuration.
    speculative: Option<SpeculativeTomlModelSelected>,

    /// Optional AnyMoE wrapper configuration.
    anymoe: Option<AnyMoeTomlModelSelected>,
}
525
/// Settings shared by every loader built from a single TOML selector.
///
/// Combines the caller's `TomlLoaderArgs` with the selector's top-level `tokenizer_json`.
/// It is `Clone` so the same settings can be reused for a speculative draft model.
#[derive(Clone)]
struct TomlLoaderInnerParams {
    /// Top-level `tokenizer_json` taken from the TOML file.
    tokenizer_json: Option<String>,
    /// Chat template supplied by the caller.
    chat_template: Option<String>,
    /// Explicit Jinja template supplied by the caller.
    jinja_explicit: Option<String>,
    /// Forwarded as the builders' `no_kv_cache` argument.
    no_kv_cache: bool,
}
533
/// Caller-supplied options used when converting a TOML selector into a loader.
///
/// Derives `Debug`, `Clone` and `Default` so callers can log it, reuse it, and build it with
/// struct-update syntax (`TomlLoaderArgs { no_kv_cache: true, ..Default::default() }`).
#[derive(Debug, Clone, Default)]
pub struct TomlLoaderArgs {
    /// Chat template forwarded to the loader builders that accept one.
    pub chat_template: Option<String>,
    /// Forwarded as `no_kv_cache` to the loader builders that accept it.
    pub no_kv_cache: bool,
    /// Explicit Jinja template forwarded to the loader builders that accept it.
    pub jinja_explicit: Option<String>,
}
539
540pub fn get_toml_selected_model_dtype(model: &TomlSelector) -> ModelDType {
541 match model.model {
542 TomlModelSelected::Plain { dtype, .. }
543 | TomlModelSelected::Lora { dtype, .. }
544 | TomlModelSelected::XLora { dtype, .. }
545 | TomlModelSelected::VisionPlain { dtype, .. }
546 | TomlModelSelected::GGUF { dtype, .. }
547 | TomlModelSelected::GGML { dtype, .. }
548 | TomlModelSelected::XLoraGGUF { dtype, .. }
549 | TomlModelSelected::XLoraGGML { dtype, .. }
550 | TomlModelSelected::LoraGGUF { dtype, .. }
551 | TomlModelSelected::LoraGGML { dtype, .. }
552 | TomlModelSelected::Embedding { dtype, .. } => dtype,
553 }
554}
555
556pub fn get_toml_selected_model_device_map_params(
557 model: &TomlSelector,
558) -> anyhow::Result<AutoDeviceMapParams> {
559 match model.model {
560 TomlModelSelected::Plain {
561 max_seq_len,
562 max_batch_size,
563 ..
564 }
565 | TomlModelSelected::Lora {
566 max_seq_len,
567 max_batch_size,
568 ..
569 }
570 | TomlModelSelected::XLora {
571 max_seq_len,
572 max_batch_size,
573 ..
574 }
575 | TomlModelSelected::GGML {
576 max_seq_len,
577 max_batch_size,
578 ..
579 }
580 | TomlModelSelected::GGUF {
581 max_seq_len,
582 max_batch_size,
583 ..
584 }
585 | TomlModelSelected::XLoraGGUF {
586 max_seq_len,
587 max_batch_size,
588 ..
589 }
590 | TomlModelSelected::XLoraGGML {
591 max_seq_len,
592 max_batch_size,
593 ..
594 }
595 | TomlModelSelected::LoraGGUF {
596 max_seq_len,
597 max_batch_size,
598 ..
599 }
600 | TomlModelSelected::LoraGGML {
601 max_seq_len,
602 max_batch_size,
603 ..
604 } => Ok(AutoDeviceMapParams::Text {
605 max_seq_len,
606 max_batch_size,
607 }),
608 TomlModelSelected::Embedding { .. } => Ok(AutoDeviceMapParams::default_text()),
609 TomlModelSelected::VisionPlain {
610 max_seq_len,
611 max_batch_size,
612 max_image_length,
613 max_num_images,
614 ..
615 } => Ok(AutoDeviceMapParams::Vision {
616 max_seq_len,
617 max_batch_size,
618 max_image_shape: (max_image_length, max_image_length),
619 max_num_images,
620 }),
621 }
622}
623
624fn loader_from_selected(
625 args: TomlLoaderInnerParams,
626 model: TomlModelSelected,
627) -> anyhow::Result<Box<dyn Loader>> {
628 let loader: Box<dyn Loader> = match model {
629 TomlModelSelected::Plain {
630 model_id,
631 arch,
632 dtype: _,
633 topology,
634 organization,
635 write_uqff,
636 from_uqff,
637 imatrix,
638 calibration_file,
639 max_seq_len: _,
640 max_batch_size: _,
641 hf_cache_path,
642 } => NormalLoaderBuilder::new(
643 NormalSpecificConfig {
644 topology: Topology::from_option_path(topology)?,
645 organization: organization.unwrap_or_default(),
646 write_uqff,
647 from_uqff: from_uqff.map(|x| {
648 x.split(UQFF_MULTI_FILE_DELIMITER)
649 .map(PathBuf::from_str)
650 .map(|x| x.unwrap())
651 .collect::<Vec<_>>()
652 }),
653 imatrix,
654 calibration_file,
655 hf_cache_path,
656 matformer_config_path: None,
657 matformer_slice_name: None,
658 },
659 args.chat_template,
660 args.tokenizer_json,
661 Some(model_id),
662 args.no_kv_cache,
663 args.jinja_explicit,
664 )
665 .build(arch)?,
666 TomlModelSelected::XLora {
667 model_id,
668 xlora_model_id,
669 order,
670 tgt_non_granular_index,
671 arch,
672 dtype: _,
673 topology,
674 write_uqff,
675 from_uqff,
676 max_seq_len: _,
677 max_batch_size: _,
678 hf_cache_path,
679 } => NormalLoaderBuilder::new(
680 NormalSpecificConfig {
681 topology: Topology::from_option_path(topology)?,
682 organization: Default::default(),
683 write_uqff,
684 from_uqff: from_uqff.map(|x| {
685 x.split(UQFF_MULTI_FILE_DELIMITER)
686 .map(PathBuf::from_str)
687 .map(|x| x.unwrap())
688 .collect::<Vec<_>>()
689 }),
690 imatrix: None,
691 calibration_file: None,
692 hf_cache_path,
693 matformer_config_path: None,
694 matformer_slice_name: None,
695 },
696 args.chat_template,
697 args.tokenizer_json,
698 model_id,
699 args.no_kv_cache,
700 args.jinja_explicit,
701 )
702 .with_xlora(
703 xlora_model_id,
704 serde_json::from_reader(
705 File::open(order.clone())
706 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
707 )?,
708 args.no_kv_cache,
709 tgt_non_granular_index,
710 )
711 .build(arch)?,
712 TomlModelSelected::Lora {
713 model_id,
714 adapter_model_ids,
715 arch,
716 dtype: _,
717 topology,
718 write_uqff,
719 from_uqff,
720 max_seq_len: _,
721 max_batch_size: _,
722 hf_cache_path,
723 } => NormalLoaderBuilder::new(
724 NormalSpecificConfig {
725 topology: Topology::from_option_path(topology)?,
726 organization: Default::default(),
727 write_uqff,
728 from_uqff: from_uqff.map(|x| {
729 x.split(UQFF_MULTI_FILE_DELIMITER)
730 .map(PathBuf::from_str)
731 .map(|x| x.unwrap())
732 .collect::<Vec<_>>()
733 }),
734 imatrix: None,
735 calibration_file: None,
736 hf_cache_path,
737 matformer_config_path: None,
738 matformer_slice_name: None,
739 },
740 args.chat_template,
741 args.tokenizer_json,
742 model_id,
743 args.no_kv_cache,
744 args.jinja_explicit,
745 )
746 .with_lora(
747 adapter_model_ids
748 .split(MULTI_LORA_DELIMITER)
749 .map(ToString::to_string)
750 .collect(),
751 )
752 .build(arch)?,
753 TomlModelSelected::GGUF {
754 tok_model_id,
755 quantized_model_id,
756 quantized_filename,
757 topology,
758 dtype: _,
759 max_seq_len: _,
760 max_batch_size: _,
761 } => GGUFLoaderBuilder::new(
762 args.chat_template,
763 Some(tok_model_id),
764 quantized_model_id,
765 quantized_filename
766 .split(GGUF_MULTI_FILE_DELIMITER)
767 .map(ToOwned::to_owned)
768 .collect::<Vec<_>>(),
769 GGUFSpecificConfig {
770 topology: Topology::from_option_path(topology)?,
771 },
772 args.no_kv_cache,
773 args.jinja_explicit,
774 )
775 .build(),
776 TomlModelSelected::XLoraGGUF {
777 tok_model_id,
778 quantized_model_id,
779 quantized_filename,
780 xlora_model_id,
781 order,
782 tgt_non_granular_index,
783 topology,
784 dtype: _,
785 max_seq_len: _,
786 max_batch_size: _,
787 } => GGUFLoaderBuilder::new(
788 args.chat_template,
789 tok_model_id,
790 quantized_model_id,
791 quantized_filename
792 .split(GGUF_MULTI_FILE_DELIMITER)
793 .map(ToOwned::to_owned)
794 .collect::<Vec<_>>(),
795 GGUFSpecificConfig {
796 topology: Topology::from_option_path(topology)?,
797 },
798 args.no_kv_cache,
799 args.jinja_explicit,
800 )
801 .with_xlora(
802 xlora_model_id,
803 serde_json::from_reader(
804 File::open(order.clone())
805 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
806 )?,
807 args.no_kv_cache,
808 tgt_non_granular_index,
809 )
810 .build(),
811 TomlModelSelected::LoraGGUF {
812 tok_model_id,
813 quantized_model_id,
814 quantized_filename,
815 adapters_model_id,
816 order,
817 topology,
818 ..
819 } => GGUFLoaderBuilder::new(
820 args.chat_template,
821 tok_model_id,
822 quantized_model_id,
823 quantized_filename
824 .split(GGUF_MULTI_FILE_DELIMITER)
825 .map(ToOwned::to_owned)
826 .collect::<Vec<_>>(),
827 GGUFSpecificConfig {
828 topology: Topology::from_option_path(topology)?,
829 },
830 args.no_kv_cache,
831 args.jinja_explicit,
832 )
833 .with_lora(
834 adapters_model_id,
835 serde_json::from_reader(
836 File::open(order.clone())
837 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
838 )?,
839 )
840 .build(),
841 TomlModelSelected::GGML {
842 tok_model_id,
843 quantized_model_id,
844 quantized_filename,
845 gqa,
846 topology,
847 dtype: _,
848 max_seq_len: _,
849 max_batch_size: _,
850 } => GGMLLoaderBuilder::new(
851 GGMLSpecificConfig {
852 gqa,
853 topology: Topology::from_option_path(topology)?,
854 },
855 args.chat_template,
856 args.tokenizer_json,
857 Some(tok_model_id),
858 quantized_model_id,
859 quantized_filename,
860 args.no_kv_cache,
861 args.jinja_explicit,
862 )
863 .build(),
864 TomlModelSelected::XLoraGGML {
865 tok_model_id,
866 quantized_model_id,
867 quantized_filename,
868 xlora_model_id,
869 order,
870 tgt_non_granular_index,
871 gqa,
872 topology,
873 dtype: _,
874 max_seq_len: _,
875 max_batch_size: _,
876 } => GGMLLoaderBuilder::new(
877 GGMLSpecificConfig {
878 gqa,
879 topology: Topology::from_option_path(topology)?,
880 },
881 args.chat_template,
882 args.tokenizer_json,
883 tok_model_id,
884 quantized_model_id,
885 quantized_filename,
886 args.no_kv_cache,
887 args.jinja_explicit,
888 )
889 .with_xlora(
890 xlora_model_id,
891 serde_json::from_reader(
892 File::open(order.clone())
893 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
894 )?,
895 args.no_kv_cache,
896 tgt_non_granular_index,
897 )
898 .build(),
899 TomlModelSelected::LoraGGML {
900 tok_model_id,
901 quantized_model_id,
902 quantized_filename,
903 adapters_model_id,
904 order,
905 gqa,
906 topology,
907 dtype: _,
908 max_seq_len: _,
909 max_batch_size: _,
910 } => GGMLLoaderBuilder::new(
911 GGMLSpecificConfig {
912 gqa,
913 topology: Topology::from_option_path(topology)?,
914 },
915 args.chat_template,
916 args.tokenizer_json,
917 tok_model_id,
918 quantized_model_id,
919 quantized_filename,
920 args.no_kv_cache,
921 args.jinja_explicit,
922 )
923 .with_lora(
924 adapters_model_id,
925 serde_json::from_reader(
926 File::open(order.clone())
927 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
928 )?,
929 )
930 .build(),
931 TomlModelSelected::VisionPlain {
932 model_id,
933 arch,
934 dtype: _,
935 topology,
936 write_uqff,
937 from_uqff,
938 max_edge,
939 calibration_file,
940 max_seq_len: _,
941 max_batch_size: _,
942 max_num_images: _,
943 max_image_length: _,
944 imatrix,
945 hf_cache_path,
946 } => VisionLoaderBuilder::new(
947 VisionSpecificConfig {
948 topology: Topology::from_option_path(topology)?,
949 write_uqff,
950 from_uqff: from_uqff.map(|x| {
951 x.split(UQFF_MULTI_FILE_DELIMITER)
952 .map(PathBuf::from_str)
953 .map(|x| x.unwrap())
954 .collect::<Vec<_>>()
955 }),
956 max_edge,
957 calibration_file,
958 imatrix,
959 hf_cache_path,
960 matformer_config_path: None,
961 matformer_slice_name: None,
962 },
963 args.chat_template,
964 args.tokenizer_json,
965 Some(model_id),
966 args.jinja_explicit,
967 )
968 .build(arch),
969 TomlModelSelected::Embedding {
970 model_id,
971 tokenizer_json,
972 arch,
973 dtype: _,
974 topology,
975 write_uqff,
976 from_uqff,
977 hf_cache_path,
978 } => EmbeddingLoaderBuilder::new(
979 EmbeddingSpecificConfig {
980 topology: Topology::from_option_path(topology)?,
981 write_uqff,
982 from_uqff: from_uqff.map(|x| {
983 x.split(UQFF_MULTI_FILE_DELIMITER)
984 .map(PathBuf::from_str)
985 .map(|x| x.unwrap())
986 .collect::<Vec<_>>()
987 }),
988 hf_cache_path,
989 },
990 tokenizer_json,
991 Some(model_id),
992 )
993 .build(arch),
994 };
995 Ok(loader)
996}
997
998impl TryInto<Box<dyn Loader>> for (TomlSelector, TomlLoaderArgs) {
999 type Error = anyhow::Error;
1000 fn try_into(self) -> Result<Box<dyn Loader>, Self::Error> {
1001 let (selector, args) = self;
1002 let args = TomlLoaderInnerParams {
1003 chat_template: args.chat_template,
1004 no_kv_cache: args.no_kv_cache,
1005 tokenizer_json: selector.tokenizer_json,
1006 jinja_explicit: args.jinja_explicit,
1007 };
1008 let loader = loader_from_selected(args.clone(), selector.model)?;
1009 let loader = if let Some(speculative) = selector.speculative {
1010 let draft_loader = loader_from_selected(args, speculative.draft_model)?;
1011 Box::new(SpeculativeLoader {
1012 target: loader,
1013 draft: draft_loader,
1014 config: SpeculativeConfig {
1015 gamma: speculative.gamma,
1016 },
1017 })
1018 } else {
1019 loader
1020 };
1021 let loader = if let Some(AnyMoeTomlModelSelected {
1022 config,
1023 dataset_json,
1024 prefix,
1025 mlp,
1026 model_ids,
1027 layers,
1028 }) = selector.anymoe
1029 {
1030 Box::new(AnyMoeLoader {
1031 target: loader,
1032 config,
1033 path: dataset_json,
1034 prefix,
1035 mlp,
1036 model_ids,
1037 layers,
1038 })
1039 } else {
1040 loader
1041 };
1042 Ok(loader)
1043 }
1044}