// mistralrs_core/toml_selector.rs

use std::{fs::File, path::PathBuf, str::FromStr};

use mistralrs_quant::MULTI_LORA_DELIMITER;
use serde::Deserialize;

use crate::{
    amoe::AnyMoeConfig,
    pipeline::{EmbeddingLoaderType, IsqOrganization},
    AnyMoeLoader, AutoDeviceMapParams, EmbeddingLoaderBuilder, EmbeddingSpecificConfig,
    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader,
    ModelDType, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig,
    SpeculativeLoader, Topology, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
    GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
};

fn default_one() -> usize {
    1
}

fn default_dtype() -> ModelDType {
    ModelDType::Auto
}

fn default_empty_vec_usize() -> Vec<usize> {
    Vec::new()
}

fn default_max_seq_len() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN
}

fn default_max_batch_size() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE
}

fn default_max_num_images() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES
}

fn default_max_image_length() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum TomlModelSelected {
    /// Select a plain model, without quantization or adapters
    Plain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        model_id: String,

        /// The architecture of the model.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// ISQ organization: `default` or `moqe` (Mixture of Quantized Experts: https://arxiv.org/abs/2310.02410).
        organization: Option<IsqOrganization>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
        from_uqff: Option<String>,

        /// .imatrix file to enhance GGUF quantizations with.
        /// Incompatible with `calibration_file`.
        imatrix: Option<PathBuf>,

        /// Generate and utilize an imatrix to enhance GGUF quantizations.
        /// Incompatible with `imatrix`.
        calibration_file: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },

    /// Select an X-LoRA architecture
    XLora {
        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
        model_id: Option<String>,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        xlora_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Index of completion tokens up to which scalings are generated. If this is 1, one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        tgt_non_granular_index: Option<usize>,

        /// The architecture of the model.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
        from_uqff: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },

    /// Select a LoRA architecture
    Lora {
        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
        model_id: Option<String>,

        /// Model IDs to load LoRA from. This may be a HF hub repo or a local path. Specify multiple with a semicolon.
        adapter_model_ids: String,

        /// The architecture of the model.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
        from_uqff: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },

    /// Select a GGUF model.
    #[allow(clippy::upper_case_acronyms)]
    GGUF {
        /// `tok_model_id` is the local or remote model ID where a `tokenizer_config.json` file can be found.
        /// If `chat_template` is specified, it is treated as a path and used instead of remote files,
        /// avoiding all remote accesses.
        tok_model_id: String,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        quantized_filename: String,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGUF model with X-LoRA.
    XLoraGGUF {
        /// `tok_model_id` is the local or remote model ID where a `tokenizer_config.json` file can be found.
        /// If `chat_template` is specified, it is treated as a path and used instead of remote files,
        /// avoiding all remote accesses.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        quantized_filename: String,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        xlora_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Index of completion tokens up to which scalings are generated. If this is 1, one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        tgt_non_granular_index: Option<usize>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGUF model with LoRA.
    LoraGGUF {
        /// `tok_model_id` is the local or remote model ID where a `tokenizer_config.json` file can be found.
        /// If `chat_template` is specified, it is treated as a path and used instead of remote files,
        /// avoiding all remote accesses.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        quantized_filename: String,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        adapters_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGML model.
    #[allow(clippy::upper_case_acronyms)]
    GGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        tok_model_id: String,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename.
        quantized_filename: String,

        /// GQA value (grouped-query attention factor: the number of query heads per KV head)
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGML model with X-LoRA.
    XLoraGGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename.
        quantized_filename: String,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        xlora_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Index of completion tokens up to which scalings are generated. If this is 1, one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        tgt_non_granular_index: Option<usize>,

        /// GQA value (grouped-query attention factor: the number of query heads per KV head)
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGML model with LoRA.
    LoraGGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename.
        quantized_filename: String,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        adapters_model_id: String,

        /// Ordering JSON file
        order: String,

        /// GQA value (grouped-query attention factor: the number of query heads per KV head)
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a vision plain model, without quantization or adapters
    VisionPlain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        model_id: String,

        /// The architecture of the model.
        arch: Option<VisionLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
        from_uqff: Option<String>,

        /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
        /// This is only supported on the Qwen2-VL and Idefics 2 models. Others handle this internally.
        max_edge: Option<u32>,

        /// Generate and utilize an imatrix to enhance GGUF quantizations.
        calibration_file: Option<PathBuf>,

        /// .cimatrix file to enhance GGUF quantizations with; the file must be in the .cimatrix format.
        imatrix: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Maximum number of images to expect per prompt for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_num_images")]
        max_num_images: usize,

        /// Maximum expected image size; images are assumed to have this edge length on both edges.
        /// This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_image_length")]
        max_image_length: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },

    /// Select an embedding model, without quantization or adapters
    Embedding {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        model_id: String,

        /// Path to a local tokenizer.json file. If specified, it is used instead of any remote file.
        #[serde(default)]
        tokenizer_json: Option<String>,

        /// The architecture of the model.
        #[serde(default)]
        arch: Option<EmbeddingLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[serde(default)]
        topology: Option<String>,

        /// UQFF path to write to.
        #[serde(default)]
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
        #[serde(default)]
        from_uqff: Option<String>,

        /// Cache path for Hugging Face models downloaded locally
        #[serde(default)]
        hf_cache_path: Option<PathBuf>,
    },
}
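
// A minimal sketch of TOML that selects a model via the variants above. Since
// the enum is `#[serde(untagged)]`, the variant is chosen by which fields are
// present (the model IDs below are illustrative placeholders):
//
//     # `Plain`: only `model_id` is required.
//     [model]
//     model_id = "meta-llama/Llama-3.1-8B-Instruct"
//
//     # `GGUF`: providing `tok_model_id`, `quantized_model_id` and
//     # `quantized_filename` selects the quantized loader instead.
//     [model]
//     tok_model_id = "meta-llama/Llama-3.1-8B-Instruct"
//     quantized_model_id = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
//     quantized_filename = "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"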

#[derive(Deserialize)]
pub struct SpeculativeTomlModelSelected {
    /// Gamma value for speculative decoding (the number of draft tokens proposed per step)
    gamma: usize,

    /// Draft model
    draft_model: TomlModelSelected,
}

#[derive(Deserialize)]
pub struct AnyMoeTomlModelSelected {
    /// Config
    config: AnyMoeConfig,

    /// Path to the dataset JSON file
    dataset_json: String,

    /// Prefix of the mlp key (the part before the layer number: "a.b.c" in "a.b.c.0.mlp")
    prefix: String,

    /// Name of the mlp key (the part after the layer number: "mlp" in "a.b.c.0.mlp")
    mlp: String,

    /// Expert model ids
    model_ids: Vec<String>,

    /// Layer IDs (zero-indexed) of the layers to apply AnyMoE to; if empty, all layers are used
    #[serde(default = "default_empty_vec_usize")]
    layers: Vec<usize>,
}
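
// An `[anymoe]` section might look like the following (a sketch; the paths and
// ids are illustrative, and `[anymoe.config]` is a table matching `AnyMoeConfig`):
//
//     [anymoe]
//     dataset_json = "examples/amoe.json"
//     prefix = "model.layers"
//     mlp = "mlp"
//     model_ids = ["HuggingFaceH4/zephyr-7b-beta"]
//     layers = [0, 1]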

#[derive(Deserialize)]
pub struct TomlSelector {
    /// Path to a local tokenizer.json file. If specified, it is used instead of any remote file.
    tokenizer_json: Option<String>,

    /// Selected model
    model: TomlModelSelected,

    /// Speculative model selector
    speculative: Option<SpeculativeTomlModelSelected>,

    /// AnyMoE config
    anymoe: Option<AnyMoeTomlModelSelected>,
}
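
// Putting it together, a selector with a speculative draft model might look
// like this (a sketch; `gamma` and the model IDs are illustrative):
//
//     [model]
//     model_id = "meta-llama/Llama-3.1-8B-Instruct"
//
//     [speculative]
//     gamma = 32
//
//     [speculative.draft_model]
//     model_id = "meta-llama/Llama-3.2-1B-Instruct"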

#[derive(Clone)]
struct TomlLoaderInnerParams {
    chat_template: Option<String>,
    no_kv_cache: bool,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
}

pub struct TomlLoaderArgs {
    pub chat_template: Option<String>,
    pub no_kv_cache: bool,
    pub jinja_explicit: Option<String>,
}

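/// Return the model data type requested by the selected model (defaults to `auto`).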
pub fn get_toml_selected_model_dtype(model: &TomlSelector) -> ModelDType {
    match model.model {
        TomlModelSelected::Plain { dtype, .. }
        | TomlModelSelected::Lora { dtype, .. }
        | TomlModelSelected::XLora { dtype, .. }
        | TomlModelSelected::VisionPlain { dtype, .. }
        | TomlModelSelected::GGUF { dtype, .. }
        | TomlModelSelected::GGML { dtype, .. }
        | TomlModelSelected::XLoraGGUF { dtype, .. }
        | TomlModelSelected::XLoraGGML { dtype, .. }
        | TomlModelSelected::LoraGGUF { dtype, .. }
        | TomlModelSelected::LoraGGML { dtype, .. }
        | TomlModelSelected::Embedding { dtype, .. } => dtype,
    }
}

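/// Derive automatic device-mapping parameters from the selected model. Text
/// models use `max_seq_len`/`max_batch_size`; vision models additionally use
/// the expected number of images and the maximum image edge length.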
pub fn get_toml_selected_model_device_map_params(
    model: &TomlSelector,
) -> anyhow::Result<AutoDeviceMapParams> {
    match model.model {
        TomlModelSelected::Plain {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::Lora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::XLora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::GGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::GGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::XLoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::XLoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::LoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::LoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        } => Ok(AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        }),
        TomlModelSelected::Embedding { .. } => Ok(AutoDeviceMapParams::default_text()),
        TomlModelSelected::VisionPlain {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => Ok(AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (max_image_length, max_image_length),
            max_num_images,
        }),
    }
}

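/// Build a `Loader` for a single selected model, threading the shared runtime
/// arguments (chat template, tokenizer, KV-cache and Jinja settings) through to
/// the appropriate builder.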
fn loader_from_selected(
    args: TomlLoaderInnerParams,
    model: TomlModelSelected,
) -> anyhow::Result<Box<dyn Loader>> {
    let loader: Box<dyn Loader> = match model {
        TomlModelSelected::Plain {
            model_id,
            arch,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                organization: organization.unwrap_or_default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix,
                calibration_file,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            args.tokenizer_json,
            Some(model_id),
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(arch)?,
        TomlModelSelected::XLora {
            model_id,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            args.tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(arch)?,
        TomlModelSelected::Lora {
            model_id,
            adapter_model_ids,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            args.tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapter_model_ids
                .split(MULTI_LORA_DELIMITER)
                .map(ToString::to_string)
                .collect(),
        )
        .build(arch)?,
        TomlModelSelected::GGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        TomlModelSelected::XLoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        TomlModelSelected::LoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        TomlModelSelected::GGML {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            gqa,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            args.tokenizer_json,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        TomlModelSelected::XLoraGGML {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            gqa,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            args.tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        TomlModelSelected::LoraGGML {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            gqa,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            args.tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        TomlModelSelected::VisionPlain {
            model_id,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_edge,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            imatrix,
            hf_cache_path,
        } => VisionLoaderBuilder::new(
            VisionSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                max_edge,
                calibration_file,
                imatrix,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            args.tokenizer_json,
            Some(model_id),
            args.jinja_explicit,
        )
        .build(arch),
        TomlModelSelected::Embedding {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            hf_cache_path,
        } => EmbeddingLoaderBuilder::new(
            EmbeddingSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                hf_cache_path,
            },
            tokenizer_json,
            Some(model_id),
        )
        .build(arch),
    };
    Ok(loader)
}

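// Conversion entry point: a parsed selector plus caller-supplied arguments
// yields a loader, wrapped for speculative decoding and/or AnyMoE when those
// sections are present in the TOML.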
impl TryInto<Box<dyn Loader>> for (TomlSelector, TomlLoaderArgs) {
    type Error = anyhow::Error;
    fn try_into(self) -> Result<Box<dyn Loader>, Self::Error> {
        let (selector, args) = self;
        let args = TomlLoaderInnerParams {
            chat_template: args.chat_template,
            no_kv_cache: args.no_kv_cache,
            tokenizer_json: selector.tokenizer_json,
            jinja_explicit: args.jinja_explicit,
        };
        let loader = loader_from_selected(args.clone(), selector.model)?;
        let loader = if let Some(speculative) = selector.speculative {
            let draft_loader = loader_from_selected(args, speculative.draft_model)?;
            Box::new(SpeculativeLoader {
                target: loader,
                draft: draft_loader,
                config: SpeculativeConfig {
                    gamma: speculative.gamma,
                },
            })
        } else {
            loader
        };
        let loader = if let Some(AnyMoeTomlModelSelected {
            config,
            dataset_json,
            prefix,
            mlp,
            model_ids,
            layers,
        }) = selector.anymoe
        {
            Box::new(AnyMoeLoader {
                target: loader,
                config,
                path: dataset_json,
                prefix,
                mlp,
                model_ids,
                layers,
            })
        } else {
            loader
        };
        Ok(loader)
    }
}
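
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal round-trip sketch: parse a selector from TOML and convert it,
    // together with runtime arguments, into a loader. This assumes the `toml`
    // crate is available to this crate (e.g. as a dev-dependency); the model ID
    // is an illustrative placeholder, and no `arch` is given since the `Plain`
    // variant's field is optional.
    #[test]
    fn plain_selector_builds_a_loader() -> anyhow::Result<()> {
        let selector: TomlSelector = toml::from_str(
            r#"
            [model]
            model_id = "meta-llama/Llama-3.1-8B-Instruct"
        "#,
        )?;
        let args = TomlLoaderArgs {
            chat_template: None,
            no_kv_cache: false,
            jinja_explicit: None,
        };
        let _loader: Box<dyn Loader> = (selector, args).try_into()?;
        Ok(())
    }
}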