1use std::{fs::File, num::NonZeroUsize, path::PathBuf, str::FromStr};
2
3use mistralrs_quant::MULTI_LORA_DELIMITER;
4use serde::Deserialize;
5
6use crate::{
7 amoe::AnyMoeConfig, pipeline::IsqOrganization, AnyMoeLoader, AutoDeviceMapParams,
8 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader,
9 ModelDType, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig,
10 SpeculativeLoader, Topology, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
11 GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
12};
13
/// Serde default of `1`, used for the `gqa` field of the GGML model variants.
fn default_one() -> usize {
    1_usize
}
17
18fn default_dtype() -> ModelDType {
19 ModelDType::Auto
20}
21
/// Serde default producing an empty `Vec<usize>` (used for the AnyMoE `layers` list).
fn default_empty_vec_usize() -> Vec<usize> {
    vec![]
}
25
26fn default_max_seq_len() -> usize {
27 AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN
28}
29
30fn default_max_batch_size() -> usize {
31 AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE
32}
33
34fn default_max_num_images() -> usize {
35 AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES
36}
37
38fn default_max_image_length() -> usize {
39 AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH
40}
41
42#[derive(Debug, Deserialize)]
43#[serde(untagged)]
44pub enum TomlModelSelected {
45 Plain {
47 model_id: String,
49
50 arch: Option<NormalLoaderType>,
52
53 #[serde(default = "default_dtype")]
55 dtype: ModelDType,
56
57 topology: Option<String>,
59
60 organization: Option<IsqOrganization>,
62
63 write_uqff: Option<PathBuf>,
65
66 from_uqff: Option<String>,
68
69 imatrix: Option<PathBuf>,
72
73 calibration_file: Option<PathBuf>,
76
77 #[serde(default = "default_max_seq_len")]
79 max_seq_len: usize,
80
81 #[serde(default = "default_max_batch_size")]
83 max_batch_size: usize,
84
85 hf_cache_path: Option<PathBuf>,
87 },
88
89 XLora {
91 model_id: Option<String>,
93
94 xlora_model_id: String,
96
97 order: String,
99
100 tgt_non_granular_index: Option<usize>,
103
104 arch: Option<NormalLoaderType>,
106
107 #[serde(default = "default_dtype")]
109 dtype: ModelDType,
110
111 topology: Option<String>,
113
114 write_uqff: Option<PathBuf>,
116
117 from_uqff: Option<String>,
119
120 #[serde(default = "default_max_seq_len")]
122 max_seq_len: usize,
123
124 #[serde(default = "default_max_batch_size")]
126 max_batch_size: usize,
127
128 hf_cache_path: Option<PathBuf>,
130 },
131
132 Lora {
134 model_id: Option<String>,
136
137 adapter_model_ids: String,
139
140 arch: Option<NormalLoaderType>,
142
143 #[serde(default = "default_dtype")]
145 dtype: ModelDType,
146
147 topology: Option<String>,
149
150 write_uqff: Option<PathBuf>,
152
153 from_uqff: Option<String>,
155
156 #[serde(default = "default_max_seq_len")]
158 max_seq_len: usize,
159
160 #[serde(default = "default_max_batch_size")]
162 max_batch_size: usize,
163
164 hf_cache_path: Option<PathBuf>,
166 },
167
168 #[allow(clippy::upper_case_acronyms)]
170 GGUF {
171 tok_model_id: String,
175
176 quantized_model_id: String,
179
180 quantized_filename: String,
183
184 #[serde(default = "default_dtype")]
186 dtype: ModelDType,
187
188 topology: Option<String>,
190
191 #[serde(default = "default_max_seq_len")]
193 max_seq_len: usize,
194
195 #[serde(default = "default_max_batch_size")]
197 max_batch_size: usize,
198 },
199
200 XLoraGGUF {
202 tok_model_id: Option<String>,
206
207 quantized_model_id: String,
210
211 quantized_filename: String,
214
215 xlora_model_id: String,
217
218 order: String,
220
221 tgt_non_granular_index: Option<usize>,
224
225 #[serde(default = "default_dtype")]
227 dtype: ModelDType,
228
229 topology: Option<String>,
231
232 #[serde(default = "default_max_seq_len")]
234 max_seq_len: usize,
235
236 #[serde(default = "default_max_batch_size")]
238 max_batch_size: usize,
239 },
240
241 LoraGGUF {
243 tok_model_id: Option<String>,
247
248 quantized_model_id: String,
251
252 quantized_filename: String,
255
256 adapters_model_id: String,
258
259 order: String,
261
262 #[serde(default = "default_dtype")]
264 dtype: ModelDType,
265
266 topology: Option<String>,
268
269 #[serde(default = "default_max_seq_len")]
271 max_seq_len: usize,
272
273 #[serde(default = "default_max_batch_size")]
275 max_batch_size: usize,
276 },
277
278 #[allow(clippy::upper_case_acronyms)]
280 GGML {
281 tok_model_id: String,
283
284 quantized_model_id: String,
287
288 quantized_filename: String,
290
291 #[serde(default = "default_one")]
293 gqa: usize,
294
295 #[serde(default = "default_dtype")]
297 dtype: ModelDType,
298
299 topology: Option<String>,
301
302 #[serde(default = "default_max_seq_len")]
304 max_seq_len: usize,
305
306 #[serde(default = "default_max_batch_size")]
308 max_batch_size: usize,
309 },
310
311 XLoraGGML {
313 tok_model_id: Option<String>,
315
316 quantized_model_id: String,
319
320 quantized_filename: String,
322
323 xlora_model_id: String,
325
326 order: String,
328
329 tgt_non_granular_index: Option<usize>,
332
333 #[serde(default = "default_one")]
335 gqa: usize,
336
337 #[serde(default = "default_dtype")]
339 dtype: ModelDType,
340
341 topology: Option<String>,
343
344 #[serde(default = "default_max_seq_len")]
346 max_seq_len: usize,
347
348 #[serde(default = "default_max_batch_size")]
350 max_batch_size: usize,
351 },
352
353 LoraGGML {
355 tok_model_id: Option<String>,
357
358 quantized_model_id: String,
361
362 quantized_filename: String,
364
365 adapters_model_id: String,
367
368 order: String,
370
371 #[serde(default = "default_one")]
373 gqa: usize,
374
375 #[serde(default = "default_dtype")]
377 dtype: ModelDType,
378
379 topology: Option<String>,
381
382 #[serde(default = "default_max_seq_len")]
384 max_seq_len: usize,
385
386 #[serde(default = "default_max_batch_size")]
388 max_batch_size: usize,
389 },
390
391 VisionPlain {
393 model_id: String,
395
396 arch: Option<VisionLoaderType>,
398
399 #[serde(default = "default_dtype")]
401 dtype: ModelDType,
402
403 topology: Option<String>,
405
406 write_uqff: Option<PathBuf>,
408
409 from_uqff: Option<String>,
411
412 max_edge: Option<u32>,
415
416 calibration_file: Option<PathBuf>,
418
419 imatrix: Option<PathBuf>,
421
422 #[serde(default = "default_max_seq_len")]
424 max_seq_len: usize,
425
426 #[serde(default = "default_max_batch_size")]
428 max_batch_size: usize,
429
430 #[serde(default = "default_max_num_images")]
432 max_num_images: usize,
433
434 #[serde(default = "default_max_image_length")]
437 max_image_length: usize,
438
439 hf_cache_path: Option<PathBuf>,
441 },
442}
443
/// The `[speculative]` table of a TOML selector: enables speculative decoding.
#[derive(Deserialize)]
pub struct SpeculativeTomlModelSelected {
    /// Forwarded to `SpeculativeConfig::gamma`.
    gamma: usize,

    /// The draft model. It is built with the same loader arguments as the target model.
    draft_model: TomlModelSelected,
}
452
/// The `[anymoe]` table of a TOML selector: wraps the selected loader in an `AnyMoeLoader`.
#[derive(Deserialize)]
pub struct AnyMoeTomlModelSelected {
    /// AnyMoE configuration, forwarded as-is to `AnyMoeLoader::config`.
    config: AnyMoeConfig,

    /// Dataset path, forwarded to `AnyMoeLoader::path`.
    dataset_json: String,

    /// Forwarded to `AnyMoeLoader::prefix`.
    prefix: String,

    /// Forwarded to `AnyMoeLoader::mlp`.
    mlp: String,

    /// Expert model IDs, forwarded to `AnyMoeLoader::model_ids`.
    model_ids: Vec<String>,

    /// Layer indices forwarded to `AnyMoeLoader::layers`; empty when omitted.
    #[serde(default = "default_empty_vec_usize")]
    layers: Vec<usize>,
}
474
/// Top-level structure of a TOML model-selector file.
#[derive(Deserialize)]
pub struct TomlSelector {
    /// Optional tokenizer JSON override. It is passed to the normal, vision and GGML
    /// builders; the GGUF builder calls in `loader_from_selected` do not receive it.
    tokenizer_json: Option<String>,

    /// The model to load (`[model]` table).
    model: TomlModelSelected,

    /// Optional speculative-decoding setup (`[speculative]` table).
    speculative: Option<SpeculativeTomlModelSelected>,

    /// Optional AnyMoE setup (`[anymoe]` table), applied after speculative wrapping.
    anymoe: Option<AnyMoeTomlModelSelected>,
}
489
/// Loader settings shared by every model built from one selector.
///
/// Combines the caller-supplied `TomlLoaderArgs` with the selector's `tokenizer_json`.
/// It is cloned so that a speculative draft model receives the same values as the target.
#[derive(Clone)]
struct TomlLoaderInnerParams {
    chat_template: Option<String>,
    no_kv_cache: bool,
    tokenizer_json: Option<String>,
    prompt_chunksize: Option<NonZeroUsize>,
    jinja_explicit: Option<String>,
}
498
/// Caller-supplied settings that are not part of the TOML selector file itself.
///
/// Converted together with a `TomlSelector` into a `Box<dyn Loader>` via `try_into`.
/// Derives `Debug` and `Clone` so callers can log the arguments and reuse them; every
/// field is a standard-library type, so both derives are available.
#[derive(Debug, Clone)]
pub struct TomlLoaderArgs {
    /// Optional chat template override.
    pub chat_template: Option<String>,
    /// Disable the KV cache.
    pub no_kv_cache: bool,
    /// Optional prompt chunk size.
    pub prompt_chunksize: Option<NonZeroUsize>,
    /// Optional explicit Jinja template.
    pub jinja_explicit: Option<String>,
}
505
506pub fn get_toml_selected_model_dtype(model: &TomlSelector) -> ModelDType {
507 match model.model {
508 TomlModelSelected::Plain { dtype, .. }
509 | TomlModelSelected::Lora { dtype, .. }
510 | TomlModelSelected::XLora { dtype, .. }
511 | TomlModelSelected::VisionPlain { dtype, .. }
512 | TomlModelSelected::GGUF { dtype, .. }
513 | TomlModelSelected::GGML { dtype, .. }
514 | TomlModelSelected::XLoraGGUF { dtype, .. }
515 | TomlModelSelected::XLoraGGML { dtype, .. }
516 | TomlModelSelected::LoraGGUF { dtype, .. }
517 | TomlModelSelected::LoraGGML { dtype, .. } => dtype,
518 }
519}
520
521pub fn get_toml_selected_model_device_map_params(
522 model: &TomlSelector,
523) -> anyhow::Result<AutoDeviceMapParams> {
524 match model.model {
525 TomlModelSelected::Plain {
526 max_seq_len,
527 max_batch_size,
528 ..
529 }
530 | TomlModelSelected::Lora {
531 max_seq_len,
532 max_batch_size,
533 ..
534 }
535 | TomlModelSelected::XLora {
536 max_seq_len,
537 max_batch_size,
538 ..
539 }
540 | TomlModelSelected::GGML {
541 max_seq_len,
542 max_batch_size,
543 ..
544 }
545 | TomlModelSelected::GGUF {
546 max_seq_len,
547 max_batch_size,
548 ..
549 }
550 | TomlModelSelected::XLoraGGUF {
551 max_seq_len,
552 max_batch_size,
553 ..
554 }
555 | TomlModelSelected::XLoraGGML {
556 max_seq_len,
557 max_batch_size,
558 ..
559 }
560 | TomlModelSelected::LoraGGUF {
561 max_seq_len,
562 max_batch_size,
563 ..
564 }
565 | TomlModelSelected::LoraGGML {
566 max_seq_len,
567 max_batch_size,
568 ..
569 } => Ok(AutoDeviceMapParams::Text {
570 max_seq_len,
571 max_batch_size,
572 }),
573 TomlModelSelected::VisionPlain {
574 max_seq_len,
575 max_batch_size,
576 max_image_length,
577 max_num_images,
578 ..
579 } => Ok(AutoDeviceMapParams::Vision {
580 max_seq_len,
581 max_batch_size,
582 max_image_shape: (max_image_length, max_image_length),
583 max_num_images,
584 }),
585 }
586}
587
588fn loader_from_selected(
589 args: TomlLoaderInnerParams,
590 model: TomlModelSelected,
591) -> anyhow::Result<Box<dyn Loader>> {
592 let loader: Box<dyn Loader> = match model {
593 TomlModelSelected::Plain {
594 model_id,
595 arch,
596 dtype: _,
597 topology,
598 organization,
599 write_uqff,
600 from_uqff,
601 imatrix,
602 calibration_file,
603 max_seq_len: _,
604 max_batch_size: _,
605 hf_cache_path,
606 } => NormalLoaderBuilder::new(
607 NormalSpecificConfig {
608 prompt_chunksize: args.prompt_chunksize,
609 topology: Topology::from_option_path(topology)?,
610 organization: organization.unwrap_or_default(),
611 write_uqff,
612 from_uqff: from_uqff.map(|x| {
613 x.split(UQFF_MULTI_FILE_DELIMITER)
614 .map(PathBuf::from_str)
615 .map(|x| x.unwrap())
616 .collect::<Vec<_>>()
617 }),
618 imatrix,
619 calibration_file,
620 hf_cache_path,
621 },
622 args.chat_template,
623 args.tokenizer_json,
624 Some(model_id),
625 args.no_kv_cache,
626 args.jinja_explicit,
627 )
628 .build(arch)?,
629 TomlModelSelected::XLora {
630 model_id,
631 xlora_model_id,
632 order,
633 tgt_non_granular_index,
634 arch,
635 dtype: _,
636 topology,
637 write_uqff,
638 from_uqff,
639 max_seq_len: _,
640 max_batch_size: _,
641 hf_cache_path,
642 } => NormalLoaderBuilder::new(
643 NormalSpecificConfig {
644 prompt_chunksize: args.prompt_chunksize,
645 topology: Topology::from_option_path(topology)?,
646 organization: Default::default(),
647 write_uqff,
648 from_uqff: from_uqff.map(|x| {
649 x.split(UQFF_MULTI_FILE_DELIMITER)
650 .map(PathBuf::from_str)
651 .map(|x| x.unwrap())
652 .collect::<Vec<_>>()
653 }),
654 imatrix: None,
655 calibration_file: None,
656 hf_cache_path,
657 },
658 args.chat_template,
659 args.tokenizer_json,
660 model_id,
661 args.no_kv_cache,
662 args.jinja_explicit,
663 )
664 .with_xlora(
665 xlora_model_id,
666 serde_json::from_reader(
667 File::open(order.clone())
668 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
669 )?,
670 args.no_kv_cache,
671 tgt_non_granular_index,
672 )
673 .build(arch)?,
674 TomlModelSelected::Lora {
675 model_id,
676 adapter_model_ids,
677 arch,
678 dtype: _,
679 topology,
680 write_uqff,
681 from_uqff,
682 max_seq_len: _,
683 max_batch_size: _,
684 hf_cache_path,
685 } => NormalLoaderBuilder::new(
686 NormalSpecificConfig {
687 prompt_chunksize: args.prompt_chunksize,
688 topology: Topology::from_option_path(topology)?,
689 organization: Default::default(),
690 write_uqff,
691 from_uqff: from_uqff.map(|x| {
692 x.split(UQFF_MULTI_FILE_DELIMITER)
693 .map(PathBuf::from_str)
694 .map(|x| x.unwrap())
695 .collect::<Vec<_>>()
696 }),
697 imatrix: None,
698 calibration_file: None,
699 hf_cache_path,
700 },
701 args.chat_template,
702 args.tokenizer_json,
703 model_id,
704 args.no_kv_cache,
705 args.jinja_explicit,
706 )
707 .with_lora(
708 adapter_model_ids
709 .split(MULTI_LORA_DELIMITER)
710 .map(ToString::to_string)
711 .collect(),
712 )
713 .build(arch)?,
714 TomlModelSelected::GGUF {
715 tok_model_id,
716 quantized_model_id,
717 quantized_filename,
718 topology,
719 dtype: _,
720 max_seq_len: _,
721 max_batch_size: _,
722 } => GGUFLoaderBuilder::new(
723 args.chat_template,
724 Some(tok_model_id),
725 quantized_model_id,
726 quantized_filename
727 .split(GGUF_MULTI_FILE_DELIMITER)
728 .map(ToOwned::to_owned)
729 .collect::<Vec<_>>(),
730 GGUFSpecificConfig {
731 prompt_chunksize: args.prompt_chunksize,
732 topology: Topology::from_option_path(topology)?,
733 },
734 args.no_kv_cache,
735 args.jinja_explicit,
736 )
737 .build(),
738 TomlModelSelected::XLoraGGUF {
739 tok_model_id,
740 quantized_model_id,
741 quantized_filename,
742 xlora_model_id,
743 order,
744 tgt_non_granular_index,
745 topology,
746 dtype: _,
747 max_seq_len: _,
748 max_batch_size: _,
749 } => GGUFLoaderBuilder::new(
750 args.chat_template,
751 tok_model_id,
752 quantized_model_id,
753 quantized_filename
754 .split(GGUF_MULTI_FILE_DELIMITER)
755 .map(ToOwned::to_owned)
756 .collect::<Vec<_>>(),
757 GGUFSpecificConfig {
758 prompt_chunksize: args.prompt_chunksize,
759 topology: Topology::from_option_path(topology)?,
760 },
761 args.no_kv_cache,
762 args.jinja_explicit,
763 )
764 .with_xlora(
765 xlora_model_id,
766 serde_json::from_reader(
767 File::open(order.clone())
768 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
769 )?,
770 args.no_kv_cache,
771 tgt_non_granular_index,
772 )
773 .build(),
774 TomlModelSelected::LoraGGUF {
775 tok_model_id,
776 quantized_model_id,
777 quantized_filename,
778 adapters_model_id,
779 order,
780 topology,
781 ..
782 } => GGUFLoaderBuilder::new(
783 args.chat_template,
784 tok_model_id,
785 quantized_model_id,
786 quantized_filename
787 .split(GGUF_MULTI_FILE_DELIMITER)
788 .map(ToOwned::to_owned)
789 .collect::<Vec<_>>(),
790 GGUFSpecificConfig {
791 prompt_chunksize: args.prompt_chunksize,
792 topology: Topology::from_option_path(topology)?,
793 },
794 args.no_kv_cache,
795 args.jinja_explicit,
796 )
797 .with_lora(
798 adapters_model_id,
799 serde_json::from_reader(
800 File::open(order.clone())
801 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
802 )?,
803 )
804 .build(),
805 TomlModelSelected::GGML {
806 tok_model_id,
807 quantized_model_id,
808 quantized_filename,
809 gqa,
810 topology,
811 dtype: _,
812 max_seq_len: _,
813 max_batch_size: _,
814 } => GGMLLoaderBuilder::new(
815 GGMLSpecificConfig {
816 gqa,
817 prompt_chunksize: args.prompt_chunksize,
818 topology: Topology::from_option_path(topology)?,
819 },
820 args.chat_template,
821 args.tokenizer_json,
822 Some(tok_model_id),
823 quantized_model_id,
824 quantized_filename,
825 args.no_kv_cache,
826 args.jinja_explicit,
827 )
828 .build(),
829 TomlModelSelected::XLoraGGML {
830 tok_model_id,
831 quantized_model_id,
832 quantized_filename,
833 xlora_model_id,
834 order,
835 tgt_non_granular_index,
836 gqa,
837 topology,
838 dtype: _,
839 max_seq_len: _,
840 max_batch_size: _,
841 } => GGMLLoaderBuilder::new(
842 GGMLSpecificConfig {
843 gqa,
844 prompt_chunksize: args.prompt_chunksize,
845 topology: Topology::from_option_path(topology)?,
846 },
847 args.chat_template,
848 args.tokenizer_json,
849 tok_model_id,
850 quantized_model_id,
851 quantized_filename,
852 args.no_kv_cache,
853 args.jinja_explicit,
854 )
855 .with_xlora(
856 xlora_model_id,
857 serde_json::from_reader(
858 File::open(order.clone())
859 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
860 )?,
861 args.no_kv_cache,
862 tgt_non_granular_index,
863 )
864 .build(),
865 TomlModelSelected::LoraGGML {
866 tok_model_id,
867 quantized_model_id,
868 quantized_filename,
869 adapters_model_id,
870 order,
871 gqa,
872 topology,
873 dtype: _,
874 max_seq_len: _,
875 max_batch_size: _,
876 } => GGMLLoaderBuilder::new(
877 GGMLSpecificConfig {
878 gqa,
879 prompt_chunksize: args.prompt_chunksize,
880 topology: Topology::from_option_path(topology)?,
881 },
882 args.chat_template,
883 args.tokenizer_json,
884 tok_model_id,
885 quantized_model_id,
886 quantized_filename,
887 args.no_kv_cache,
888 args.jinja_explicit,
889 )
890 .with_lora(
891 adapters_model_id,
892 serde_json::from_reader(
893 File::open(order.clone())
894 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
895 )?,
896 )
897 .build(),
898 TomlModelSelected::VisionPlain {
899 model_id,
900 arch,
901 dtype: _,
902 topology,
903 write_uqff,
904 from_uqff,
905 max_edge,
906 calibration_file,
907 max_seq_len: _,
908 max_batch_size: _,
909 max_num_images: _,
910 max_image_length: _,
911 imatrix,
912 hf_cache_path,
913 } => VisionLoaderBuilder::new(
914 VisionSpecificConfig {
915 prompt_chunksize: args.prompt_chunksize,
916 topology: Topology::from_option_path(topology)?,
917 write_uqff,
918 from_uqff: from_uqff.map(|x| {
919 x.split(UQFF_MULTI_FILE_DELIMITER)
920 .map(PathBuf::from_str)
921 .map(|x| x.unwrap())
922 .collect::<Vec<_>>()
923 }),
924 max_edge,
925 calibration_file,
926 imatrix,
927 hf_cache_path,
928 },
929 args.chat_template,
930 args.tokenizer_json,
931 Some(model_id),
932 args.jinja_explicit,
933 )
934 .build(arch),
935 };
936 Ok(loader)
937}
938
939impl TryInto<Box<dyn Loader>> for (TomlSelector, TomlLoaderArgs) {
940 type Error = anyhow::Error;
941 fn try_into(self) -> Result<Box<dyn Loader>, Self::Error> {
942 let (selector, args) = self;
943 let args = TomlLoaderInnerParams {
944 chat_template: args.chat_template,
945 no_kv_cache: args.no_kv_cache,
946 tokenizer_json: selector.tokenizer_json,
947 prompt_chunksize: args.prompt_chunksize,
948 jinja_explicit: args.jinja_explicit,
949 };
950 let loader = loader_from_selected(args.clone(), selector.model)?;
951 let loader = if let Some(speculative) = selector.speculative {
952 let draft_loader = loader_from_selected(args, speculative.draft_model)?;
953 Box::new(SpeculativeLoader {
954 target: loader,
955 draft: draft_loader,
956 config: SpeculativeConfig {
957 gamma: speculative.gamma,
958 },
959 })
960 } else {
961 loader
962 };
963 let loader = if let Some(AnyMoeTomlModelSelected {
964 config,
965 dataset_json,
966 prefix,
967 mlp,
968 model_ids,
969 layers,
970 }) = selector.anymoe
971 {
972 Box::new(AnyMoeLoader {
973 target: loader,
974 config,
975 path: dataset_json,
976 prefix,
977 mlp,
978 model_ids,
979 layers,
980 })
981 } else {
982 loader
983 };
984 Ok(loader)
985 }
986}