1use std::{fs::File, num::NonZeroUsize, path::PathBuf, str::FromStr};
2
3use mistralrs_quant::MULTI_LORA_DELIMITER;
4use serde::Deserialize;
5
6use crate::{
7 amoe::AnyMoeConfig, pipeline::IsqOrganization, AnyMoeLoader, AutoDeviceMapParams,
8 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader,
9 ModelDType, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig,
10 SpeculativeLoader, Topology, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
11 GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
12};
13
/// Serde default for the `gqa` fields of the GGML variants: the value `1`.
fn default_one() -> usize {
    1_usize
}
17
18fn default_dtype() -> ModelDType {
19 ModelDType::Auto
20}
21
/// Serde default for `AnyMoeTomlModelSelected::layers`: an empty layer list.
fn default_empty_vec_usize() -> Vec<usize> {
    Vec::default()
}
25
26fn default_max_seq_len() -> usize {
27 AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN
28}
29
30fn default_max_batch_size() -> usize {
31 AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE
32}
33
34fn default_max_num_images() -> usize {
35 AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES
36}
37
38fn default_max_image_length() -> usize {
39 AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH
40}
41
/// Description of one model in a TOML selector: the `model` field of [`TomlSelector`]
/// and the `draft_model` field of the speculative section.
///
/// Deserialized with `#[serde(untagged)]`: serde tries each variant in declaration
/// order and keeps the first one whose fields deserialize successfully.
///
/// NOTE(review): unknown fields are not denied, so a table that also satisfies an
/// earlier variant's required fields is parsed as that earlier variant. From the field
/// lists below it appears that an `XLoraGGUF`/`XLoraGGML` table (it has `xlora_model_id`
/// and `order`) also satisfies `XLora`. A `GGML` table differs from `GGUF` only by the
/// defaulted `gqa` field. Confirm whether those later variants are actually reachable.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum TomlModelSelected {
    /// Plain text model built with `NormalLoaderBuilder` (no adapters).
    Plain {
        /// Model ID passed to `NormalLoaderBuilder::new`.
        model_id: String,

        /// Architecture passed to `NormalLoaderBuilder::build`.
        arch: Option<NormalLoaderType>,

        /// Model dtype (default `ModelDType::Auto`); read by `get_toml_selected_model_dtype`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology file path, loaded via `Topology::from_option_path`.
        topology: Option<String>,

        /// ISQ organization; `None` falls back to `IsqOrganization::default()`.
        organization: Option<IsqOrganization>,

        /// Forwarded as `NormalSpecificConfig::write_uqff`.
        write_uqff: Option<PathBuf>,

        /// One or more UQFF paths separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        /// Forwarded as `NormalSpecificConfig::imatrix`.
        imatrix: Option<PathBuf>,

        /// Forwarded as `NormalSpecificConfig::calibration_file`.
        calibration_file: Option<PathBuf>,

        /// Sequence-length limit reported in the auto device-map params.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Batch-size limit reported in the auto device-map params.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Forwarded as `NormalSpecificConfig::hf_cache_path`.
        hf_cache_path: Option<PathBuf>,
    },

    /// Text model with an X-LoRA adapter, built with `NormalLoaderBuilder::with_xlora`.
    XLora {
        /// Optional base model ID passed to `NormalLoaderBuilder::new`.
        model_id: Option<String>,

        /// X-LoRA model ID passed to `with_xlora`.
        xlora_model_id: String,

        /// Path of a JSON ordering file; opened and deserialized for `with_xlora`.
        order: String,

        /// Passed through to `with_xlora`.
        tgt_non_granular_index: Option<usize>,

        /// Architecture passed to `NormalLoaderBuilder::build`.
        arch: Option<NormalLoaderType>,

        /// Model dtype (default `ModelDType::Auto`); read by `get_toml_selected_model_dtype`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology file path, loaded via `Topology::from_option_path`.
        topology: Option<String>,

        /// Forwarded as `NormalSpecificConfig::write_uqff`.
        write_uqff: Option<PathBuf>,

        /// One or more UQFF paths separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        /// Sequence-length limit reported in the auto device-map params.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Batch-size limit reported in the auto device-map params.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Forwarded as `NormalSpecificConfig::hf_cache_path`.
        hf_cache_path: Option<PathBuf>,
    },

    /// Text model with LoRA adapters, built with `NormalLoaderBuilder::with_lora`.
    Lora {
        /// Optional base model ID passed to `NormalLoaderBuilder::new`.
        model_id: Option<String>,

        /// Adapter model IDs separated by `MULTI_LORA_DELIMITER`.
        adapter_model_ids: String,

        /// Architecture passed to `NormalLoaderBuilder::build`.
        arch: Option<NormalLoaderType>,

        /// Model dtype (default `ModelDType::Auto`); read by `get_toml_selected_model_dtype`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology file path, loaded via `Topology::from_option_path`.
        topology: Option<String>,

        /// Forwarded as `NormalSpecificConfig::write_uqff`.
        write_uqff: Option<PathBuf>,

        /// One or more UQFF paths separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        /// Sequence-length limit reported in the auto device-map params.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Batch-size limit reported in the auto device-map params.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Forwarded as `NormalSpecificConfig::hf_cache_path`.
        hf_cache_path: Option<PathBuf>,
    },

    /// GGUF model built with `GGUFLoaderBuilder`.
    #[allow(clippy::upper_case_acronyms)]
    GGUF {
        /// Tokenizer model ID passed (as `Some`) to `GGUFLoaderBuilder::new`.
        tok_model_id: String,

        /// Model ID that holds the quantized files.
        quantized_model_id: String,

        /// One or more file names separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        /// Model dtype (default `ModelDType::Auto`); read by `get_toml_selected_model_dtype`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology file path, loaded via `Topology::from_option_path`.
        topology: Option<String>,

        /// Sequence-length limit reported in the auto device-map params.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Batch-size limit reported in the auto device-map params.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// GGUF model with an X-LoRA adapter (`GGUFLoaderBuilder::with_xlora`).
    XLoraGGUF {
        /// Optional tokenizer model ID passed to `GGUFLoaderBuilder::new`.
        tok_model_id: Option<String>,

        /// Model ID that holds the quantized files.
        quantized_model_id: String,

        /// One or more file names separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        /// X-LoRA model ID passed to `with_xlora`.
        xlora_model_id: String,

        /// Path of a JSON ordering file; opened and deserialized for `with_xlora`.
        order: String,

        /// Passed through to `with_xlora`.
        tgt_non_granular_index: Option<usize>,

        /// Model dtype (default `ModelDType::Auto`); read by `get_toml_selected_model_dtype`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology file path, loaded via `Topology::from_option_path`.
        topology: Option<String>,

        /// Sequence-length limit reported in the auto device-map params.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Batch-size limit reported in the auto device-map params.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// GGUF model with LoRA adapters (`GGUFLoaderBuilder::with_lora`).
    LoraGGUF {
        /// Optional tokenizer model ID passed to `GGUFLoaderBuilder::new`.
        tok_model_id: Option<String>,

        /// Model ID that holds the quantized files.
        quantized_model_id: String,

        /// One or more file names separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        /// Adapter model ID passed to `with_lora`.
        adapters_model_id: String,

        /// Path of a JSON ordering file; opened and deserialized for `with_lora`.
        order: String,

        /// Model dtype (default `ModelDType::Auto`); read by `get_toml_selected_model_dtype`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology file path, loaded via `Topology::from_option_path`.
        topology: Option<String>,

        /// Sequence-length limit reported in the auto device-map params.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Batch-size limit reported in the auto device-map params.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// GGML model built with `GGMLLoaderBuilder`.
    #[allow(clippy::upper_case_acronyms)]
    GGML {
        /// Tokenizer model ID passed (as `Some`) to `GGMLLoaderBuilder::new`.
        tok_model_id: String,

        /// Model ID that holds the quantized file.
        quantized_model_id: String,

        /// Quantized file name, passed through unsplit.
        quantized_filename: String,

        /// Forwarded as `GGMLSpecificConfig::gqa`; defaults to 1.
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model dtype (default `ModelDType::Auto`); read by `get_toml_selected_model_dtype`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology file path, loaded via `Topology::from_option_path`.
        topology: Option<String>,

        /// Sequence-length limit reported in the auto device-map params.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Batch-size limit reported in the auto device-map params.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// GGML model with an X-LoRA adapter (`GGMLLoaderBuilder::with_xlora`).
    XLoraGGML {
        /// Optional tokenizer model ID passed to `GGMLLoaderBuilder::new`.
        tok_model_id: Option<String>,

        /// Model ID that holds the quantized file.
        quantized_model_id: String,

        /// Quantized file name, passed through unsplit.
        quantized_filename: String,

        /// X-LoRA model ID passed to `with_xlora`.
        xlora_model_id: String,

        /// Path of a JSON ordering file; opened and deserialized for `with_xlora`.
        order: String,

        /// Passed through to `with_xlora`.
        tgt_non_granular_index: Option<usize>,

        /// Forwarded as `GGMLSpecificConfig::gqa`; defaults to 1.
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model dtype (default `ModelDType::Auto`); read by `get_toml_selected_model_dtype`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology file path, loaded via `Topology::from_option_path`.
        topology: Option<String>,

        /// Sequence-length limit reported in the auto device-map params.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Batch-size limit reported in the auto device-map params.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// GGML model with LoRA adapters (`GGMLLoaderBuilder::with_lora`).
    LoraGGML {
        /// Optional tokenizer model ID passed to `GGMLLoaderBuilder::new`.
        tok_model_id: Option<String>,

        /// Model ID that holds the quantized file.
        quantized_model_id: String,

        /// Quantized file name, passed through unsplit.
        quantized_filename: String,

        /// Adapter model ID passed to `with_lora`.
        adapters_model_id: String,

        /// Path of a JSON ordering file; opened and deserialized for `with_lora`.
        order: String,

        /// Forwarded as `GGMLSpecificConfig::gqa`; defaults to 1.
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model dtype (default `ModelDType::Auto`); read by `get_toml_selected_model_dtype`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology file path, loaded via `Topology::from_option_path`.
        topology: Option<String>,

        /// Sequence-length limit reported in the auto device-map params.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Batch-size limit reported in the auto device-map params.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Vision model built with `VisionLoaderBuilder`.
    VisionPlain {
        /// Model ID passed to `VisionLoaderBuilder::new`.
        model_id: String,

        /// Vision architecture passed to `VisionLoaderBuilder::build` (required).
        arch: VisionLoaderType,

        /// Model dtype (default `ModelDType::Auto`); read by `get_toml_selected_model_dtype`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Optional topology file path, loaded via `Topology::from_option_path`.
        topology: Option<String>,

        /// Forwarded as `VisionSpecificConfig::write_uqff`.
        write_uqff: Option<PathBuf>,

        /// One or more UQFF paths separated by `UQFF_MULTI_FILE_DELIMITER`.
        from_uqff: Option<String>,

        /// Forwarded as `VisionSpecificConfig::max_edge`.
        max_edge: Option<u32>,

        /// Forwarded as `VisionSpecificConfig::calibration_file`.
        calibration_file: Option<PathBuf>,

        /// Forwarded as `VisionSpecificConfig::imatrix`.
        imatrix: Option<PathBuf>,

        /// Sequence-length limit reported in the auto device-map params.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Batch-size limit reported in the auto device-map params.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Image-count limit reported in the auto device-map params.
        #[serde(default = "default_max_num_images")]
        max_num_images: usize,

        /// Used for both dimensions of `max_image_shape` in the auto device-map params.
        #[serde(default = "default_max_image_length")]
        max_image_length: usize,

        /// Forwarded as `VisionSpecificConfig::hf_cache_path`.
        hf_cache_path: Option<PathBuf>,
    },
}
443
/// Optional `speculative` section of a TOML selector.
///
/// When present, `draft_model` is built with the same loader arguments as the main
/// model, and both are wrapped in a `SpeculativeLoader`.
#[derive(Deserialize)]
pub struct SpeculativeTomlModelSelected {
    /// Copied into `SpeculativeConfig::gamma`.
    gamma: usize,

    /// The draft model's description.
    draft_model: TomlModelSelected,
}
452
/// Optional `anymoe` section of a TOML selector; its fields are moved into an `AnyMoeLoader`.
#[derive(Deserialize)]
pub struct AnyMoeTomlModelSelected {
    /// Copied into `AnyMoeLoader::config`.
    config: AnyMoeConfig,

    /// Copied into `AnyMoeLoader::path`.
    dataset_json: String,

    /// Copied into `AnyMoeLoader::prefix`.
    prefix: String,

    /// Copied into `AnyMoeLoader::mlp`.
    mlp: String,

    /// Copied into `AnyMoeLoader::model_ids`.
    model_ids: Vec<String>,

    /// Copied into `AnyMoeLoader::layers`; defaults to an empty list.
    #[serde(default = "default_empty_vec_usize")]
    layers: Vec<usize>,
}
474
/// Top-level structure of a TOML model-selector file.
#[derive(Deserialize)]
pub struct TomlSelector {
    /// Optional tokenizer JSON path shared by the main and draft models.
    /// NOTE(review): `loader_from_selected` forwards this to the Normal, GGML and Vision
    /// builders but not to `GGUFLoaderBuilder::new`; confirm that is intended.
    tokenizer_json: Option<String>,

    /// The main model.
    model: TomlModelSelected,

    /// Optional speculative-decoding section.
    speculative: Option<SpeculativeTomlModelSelected>,

    /// Optional AnyMoE section, applied on top of the (possibly speculative) loader.
    anymoe: Option<AnyMoeTomlModelSelected>,
}
489
/// Arguments shared by every loader built from one selector.
///
/// This is `TomlLoaderArgs` plus the selector's `tokenizer_json`. It is `Clone` so that
/// the main and draft models can each receive a copy.
#[derive(Clone)]
struct TomlLoaderInnerParams {
    use_flash_attn: bool,
    chat_template: Option<String>,
    no_kv_cache: bool,
    tokenizer_json: Option<String>,
    prompt_chunksize: Option<NonZeroUsize>,
    jinja_explicit: Option<String>,
}
499
/// Runtime settings supplied by the caller alongside a [`TomlSelector`].
///
/// Convert `(TomlSelector, TomlLoaderArgs)` into a `Box<dyn Loader>` with `try_into`.
pub struct TomlLoaderArgs {
    /// Forwarded as `use_flash_attn` to the Normal and Vision specific configs.
    pub use_flash_attn: bool,
    /// Chat template passed to every builder.
    pub chat_template: Option<String>,
    /// Passed to the Normal, GGUF and GGML builders (and their X-LoRA setup).
    pub no_kv_cache: bool,
    /// Forwarded as `prompt_chunksize` in every specific config.
    pub prompt_chunksize: Option<NonZeroUsize>,
    /// Passed as `jinja_explicit` to every builder.
    pub jinja_explicit: Option<String>,
}
507
508pub fn get_toml_selected_model_dtype(model: &TomlSelector) -> ModelDType {
509 match model.model {
510 TomlModelSelected::Plain { dtype, .. }
511 | TomlModelSelected::Lora { dtype, .. }
512 | TomlModelSelected::XLora { dtype, .. }
513 | TomlModelSelected::VisionPlain { dtype, .. }
514 | TomlModelSelected::GGUF { dtype, .. }
515 | TomlModelSelected::GGML { dtype, .. }
516 | TomlModelSelected::XLoraGGUF { dtype, .. }
517 | TomlModelSelected::XLoraGGML { dtype, .. }
518 | TomlModelSelected::LoraGGUF { dtype, .. }
519 | TomlModelSelected::LoraGGML { dtype, .. } => dtype,
520 }
521}
522
523pub fn get_toml_selected_model_device_map_params(
524 model: &TomlSelector,
525) -> anyhow::Result<AutoDeviceMapParams> {
526 match model.model {
527 TomlModelSelected::Plain {
528 max_seq_len,
529 max_batch_size,
530 ..
531 }
532 | TomlModelSelected::Lora {
533 max_seq_len,
534 max_batch_size,
535 ..
536 }
537 | TomlModelSelected::XLora {
538 max_seq_len,
539 max_batch_size,
540 ..
541 }
542 | TomlModelSelected::GGML {
543 max_seq_len,
544 max_batch_size,
545 ..
546 }
547 | TomlModelSelected::GGUF {
548 max_seq_len,
549 max_batch_size,
550 ..
551 }
552 | TomlModelSelected::XLoraGGUF {
553 max_seq_len,
554 max_batch_size,
555 ..
556 }
557 | TomlModelSelected::XLoraGGML {
558 max_seq_len,
559 max_batch_size,
560 ..
561 }
562 | TomlModelSelected::LoraGGUF {
563 max_seq_len,
564 max_batch_size,
565 ..
566 }
567 | TomlModelSelected::LoraGGML {
568 max_seq_len,
569 max_batch_size,
570 ..
571 } => Ok(AutoDeviceMapParams::Text {
572 max_seq_len,
573 max_batch_size,
574 }),
575 TomlModelSelected::VisionPlain {
576 max_seq_len,
577 max_batch_size,
578 max_image_length,
579 max_num_images,
580 ..
581 } => Ok(AutoDeviceMapParams::Vision {
582 max_seq_len,
583 max_batch_size,
584 max_image_shape: (max_image_length, max_image_length),
585 max_num_images,
586 }),
587 }
588}
589
590fn loader_from_selected(
591 args: TomlLoaderInnerParams,
592 model: TomlModelSelected,
593) -> anyhow::Result<Box<dyn Loader>> {
594 let use_flash_attn = args.use_flash_attn;
595 let loader: Box<dyn Loader> = match model {
596 TomlModelSelected::Plain {
597 model_id,
598 arch,
599 dtype: _,
600 topology,
601 organization,
602 write_uqff,
603 from_uqff,
604 imatrix,
605 calibration_file,
606 max_seq_len: _,
607 max_batch_size: _,
608 hf_cache_path,
609 } => NormalLoaderBuilder::new(
610 NormalSpecificConfig {
611 use_flash_attn,
612 prompt_chunksize: args.prompt_chunksize,
613 topology: Topology::from_option_path(topology)?,
614 organization: organization.unwrap_or_default(),
615 write_uqff,
616 from_uqff: from_uqff.map(|x| {
617 x.split(UQFF_MULTI_FILE_DELIMITER)
618 .map(PathBuf::from_str)
619 .map(|x| x.unwrap())
620 .collect::<Vec<_>>()
621 }),
622 imatrix,
623 calibration_file,
624 hf_cache_path,
625 },
626 args.chat_template,
627 args.tokenizer_json,
628 Some(model_id),
629 args.no_kv_cache,
630 args.jinja_explicit,
631 )
632 .build(arch)?,
633 TomlModelSelected::XLora {
634 model_id,
635 xlora_model_id,
636 order,
637 tgt_non_granular_index,
638 arch,
639 dtype: _,
640 topology,
641 write_uqff,
642 from_uqff,
643 max_seq_len: _,
644 max_batch_size: _,
645 hf_cache_path,
646 } => NormalLoaderBuilder::new(
647 NormalSpecificConfig {
648 use_flash_attn,
649 prompt_chunksize: args.prompt_chunksize,
650 topology: Topology::from_option_path(topology)?,
651 organization: Default::default(),
652 write_uqff,
653 from_uqff: from_uqff.map(|x| {
654 x.split(UQFF_MULTI_FILE_DELIMITER)
655 .map(PathBuf::from_str)
656 .map(|x| x.unwrap())
657 .collect::<Vec<_>>()
658 }),
659 imatrix: None,
660 calibration_file: None,
661 hf_cache_path,
662 },
663 args.chat_template,
664 args.tokenizer_json,
665 model_id,
666 args.no_kv_cache,
667 args.jinja_explicit,
668 )
669 .with_xlora(
670 xlora_model_id,
671 serde_json::from_reader(
672 File::open(order.clone())
673 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
674 )?,
675 args.no_kv_cache,
676 tgt_non_granular_index,
677 )
678 .build(arch)?,
679 TomlModelSelected::Lora {
680 model_id,
681 adapter_model_ids,
682 arch,
683 dtype: _,
684 topology,
685 write_uqff,
686 from_uqff,
687 max_seq_len: _,
688 max_batch_size: _,
689 hf_cache_path,
690 } => NormalLoaderBuilder::new(
691 NormalSpecificConfig {
692 use_flash_attn,
693 prompt_chunksize: args.prompt_chunksize,
694 topology: Topology::from_option_path(topology)?,
695 organization: Default::default(),
696 write_uqff,
697 from_uqff: from_uqff.map(|x| {
698 x.split(UQFF_MULTI_FILE_DELIMITER)
699 .map(PathBuf::from_str)
700 .map(|x| x.unwrap())
701 .collect::<Vec<_>>()
702 }),
703 imatrix: None,
704 calibration_file: None,
705 hf_cache_path,
706 },
707 args.chat_template,
708 args.tokenizer_json,
709 model_id,
710 args.no_kv_cache,
711 args.jinja_explicit,
712 )
713 .with_lora(
714 adapter_model_ids
715 .split(MULTI_LORA_DELIMITER)
716 .map(ToString::to_string)
717 .collect(),
718 )
719 .build(arch)?,
720 TomlModelSelected::GGUF {
721 tok_model_id,
722 quantized_model_id,
723 quantized_filename,
724 topology,
725 dtype: _,
726 max_seq_len: _,
727 max_batch_size: _,
728 } => GGUFLoaderBuilder::new(
729 args.chat_template,
730 Some(tok_model_id),
731 quantized_model_id,
732 quantized_filename
733 .split(GGUF_MULTI_FILE_DELIMITER)
734 .map(ToOwned::to_owned)
735 .collect::<Vec<_>>(),
736 GGUFSpecificConfig {
737 prompt_chunksize: args.prompt_chunksize,
738 topology: Topology::from_option_path(topology)?,
739 },
740 args.no_kv_cache,
741 args.jinja_explicit,
742 )
743 .build(),
744 TomlModelSelected::XLoraGGUF {
745 tok_model_id,
746 quantized_model_id,
747 quantized_filename,
748 xlora_model_id,
749 order,
750 tgt_non_granular_index,
751 topology,
752 dtype: _,
753 max_seq_len: _,
754 max_batch_size: _,
755 } => GGUFLoaderBuilder::new(
756 args.chat_template,
757 tok_model_id,
758 quantized_model_id,
759 quantized_filename
760 .split(GGUF_MULTI_FILE_DELIMITER)
761 .map(ToOwned::to_owned)
762 .collect::<Vec<_>>(),
763 GGUFSpecificConfig {
764 prompt_chunksize: args.prompt_chunksize,
765 topology: Topology::from_option_path(topology)?,
766 },
767 args.no_kv_cache,
768 args.jinja_explicit,
769 )
770 .with_xlora(
771 xlora_model_id,
772 serde_json::from_reader(
773 File::open(order.clone())
774 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
775 )?,
776 args.no_kv_cache,
777 tgt_non_granular_index,
778 )
779 .build(),
780 TomlModelSelected::LoraGGUF {
781 tok_model_id,
782 quantized_model_id,
783 quantized_filename,
784 adapters_model_id,
785 order,
786 topology,
787 ..
788 } => GGUFLoaderBuilder::new(
789 args.chat_template,
790 tok_model_id,
791 quantized_model_id,
792 quantized_filename
793 .split(GGUF_MULTI_FILE_DELIMITER)
794 .map(ToOwned::to_owned)
795 .collect::<Vec<_>>(),
796 GGUFSpecificConfig {
797 prompt_chunksize: args.prompt_chunksize,
798 topology: Topology::from_option_path(topology)?,
799 },
800 args.no_kv_cache,
801 args.jinja_explicit,
802 )
803 .with_lora(
804 adapters_model_id,
805 serde_json::from_reader(
806 File::open(order.clone())
807 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
808 )?,
809 )
810 .build(),
811 TomlModelSelected::GGML {
812 tok_model_id,
813 quantized_model_id,
814 quantized_filename,
815 gqa,
816 topology,
817 dtype: _,
818 max_seq_len: _,
819 max_batch_size: _,
820 } => GGMLLoaderBuilder::new(
821 GGMLSpecificConfig {
822 gqa,
823 prompt_chunksize: args.prompt_chunksize,
824 topology: Topology::from_option_path(topology)?,
825 },
826 args.chat_template,
827 args.tokenizer_json,
828 Some(tok_model_id),
829 quantized_model_id,
830 quantized_filename,
831 args.no_kv_cache,
832 args.jinja_explicit,
833 )
834 .build(),
835 TomlModelSelected::XLoraGGML {
836 tok_model_id,
837 quantized_model_id,
838 quantized_filename,
839 xlora_model_id,
840 order,
841 tgt_non_granular_index,
842 gqa,
843 topology,
844 dtype: _,
845 max_seq_len: _,
846 max_batch_size: _,
847 } => GGMLLoaderBuilder::new(
848 GGMLSpecificConfig {
849 gqa,
850 prompt_chunksize: args.prompt_chunksize,
851 topology: Topology::from_option_path(topology)?,
852 },
853 args.chat_template,
854 args.tokenizer_json,
855 tok_model_id,
856 quantized_model_id,
857 quantized_filename,
858 args.no_kv_cache,
859 args.jinja_explicit,
860 )
861 .with_xlora(
862 xlora_model_id,
863 serde_json::from_reader(
864 File::open(order.clone())
865 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
866 )?,
867 args.no_kv_cache,
868 tgt_non_granular_index,
869 )
870 .build(),
871 TomlModelSelected::LoraGGML {
872 tok_model_id,
873 quantized_model_id,
874 quantized_filename,
875 adapters_model_id,
876 order,
877 gqa,
878 topology,
879 dtype: _,
880 max_seq_len: _,
881 max_batch_size: _,
882 } => GGMLLoaderBuilder::new(
883 GGMLSpecificConfig {
884 gqa,
885 prompt_chunksize: args.prompt_chunksize,
886 topology: Topology::from_option_path(topology)?,
887 },
888 args.chat_template,
889 args.tokenizer_json,
890 tok_model_id,
891 quantized_model_id,
892 quantized_filename,
893 args.no_kv_cache,
894 args.jinja_explicit,
895 )
896 .with_lora(
897 adapters_model_id,
898 serde_json::from_reader(
899 File::open(order.clone())
900 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
901 )?,
902 )
903 .build(),
904 TomlModelSelected::VisionPlain {
905 model_id,
906 arch,
907 dtype: _,
908 topology,
909 write_uqff,
910 from_uqff,
911 max_edge,
912 calibration_file,
913 max_seq_len: _,
914 max_batch_size: _,
915 max_num_images: _,
916 max_image_length: _,
917 imatrix,
918 hf_cache_path,
919 } => VisionLoaderBuilder::new(
920 VisionSpecificConfig {
921 use_flash_attn,
922 prompt_chunksize: args.prompt_chunksize,
923 topology: Topology::from_option_path(topology)?,
924 write_uqff,
925 from_uqff: from_uqff.map(|x| {
926 x.split(UQFF_MULTI_FILE_DELIMITER)
927 .map(PathBuf::from_str)
928 .map(|x| x.unwrap())
929 .collect::<Vec<_>>()
930 }),
931 max_edge,
932 calibration_file,
933 imatrix,
934 hf_cache_path,
935 },
936 args.chat_template,
937 args.tokenizer_json,
938 Some(model_id),
939 args.jinja_explicit,
940 )
941 .build(arch),
942 };
943 Ok(loader)
944}
945
946impl TryInto<Box<dyn Loader>> for (TomlSelector, TomlLoaderArgs) {
947 type Error = anyhow::Error;
948 fn try_into(self) -> Result<Box<dyn Loader>, Self::Error> {
949 let (selector, args) = self;
950 let args = TomlLoaderInnerParams {
951 use_flash_attn: args.use_flash_attn,
952 chat_template: args.chat_template,
953 no_kv_cache: args.no_kv_cache,
954 tokenizer_json: selector.tokenizer_json,
955 prompt_chunksize: args.prompt_chunksize,
956 jinja_explicit: args.jinja_explicit,
957 };
958 let loader = loader_from_selected(args.clone(), selector.model)?;
959 let loader = if let Some(speculative) = selector.speculative {
960 let draft_loader = loader_from_selected(args, speculative.draft_model)?;
961 Box::new(SpeculativeLoader {
962 target: loader,
963 draft: draft_loader,
964 config: SpeculativeConfig {
965 gamma: speculative.gamma,
966 },
967 })
968 } else {
969 loader
970 };
971 let loader = if let Some(AnyMoeTomlModelSelected {
972 config,
973 dataset_json,
974 prefix,
975 mlp,
976 model_ids,
977 layers,
978 }) = selector.anymoe
979 {
980 Box::new(AnyMoeLoader {
981 target: loader,
982 config,
983 path: dataset_json,
984 prefix,
985 mlp,
986 model_ids,
987 layers,
988 })
989 } else {
990 loader
991 };
992 Ok(loader)
993 }
994}