// mistralrs_core/toml_selector.rs

use std::{fs::File, num::NonZeroUsize, path::PathBuf, str::FromStr};

use mistralrs_quant::MULTI_LORA_DELIMITER;
use serde::Deserialize;

use crate::{
    amoe::AnyMoeConfig, pipeline::IsqOrganization, AnyMoeLoader, AutoDeviceMapParams,
    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader,
    ModelDType, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig,
    SpeculativeLoader, Topology, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
    GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
};

fn default_one() -> usize {
    1
}

fn default_dtype() -> ModelDType {
    ModelDType::Auto
}

fn default_empty_vec_usize() -> Vec<usize> {
    Vec::new()
}

fn default_max_seq_len() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN
}

fn default_max_batch_size() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE
}

fn default_max_num_images() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES
}

fn default_max_image_length() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum TomlModelSelected {
    /// Select a plain model, without quantization or adapters
    Plain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        model_id: String,

        /// The architecture of the model.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// ISQ organization: `default` or `moqe` (Mixture of Quantized Experts: https://arxiv.org/abs/2310.02410).
        organization: Option<IsqOrganization>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        from_uqff: Option<String>,

        /// .imatrix file to enhance GGUF quantizations with.
        /// Incompatible with `--calibration-file/-c`
        imatrix: Option<PathBuf>,

        /// Generate and utilize an imatrix to enhance GGUF quantizations.
        /// Incompatible with `--imatrix/-i`
        calibration_file: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },

    /// Select an X-LoRA architecture
    XLora {
        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
        model_id: Option<String>,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        xlora_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, then one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        tgt_non_granular_index: Option<usize>,

        /// The architecture of the model.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        from_uqff: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },

    /// Select a LoRA architecture
    Lora {
        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
        model_id: Option<String>,

        /// Model IDs to load LoRA from. These may be HF hub repos or local paths. Specify multiple IDs separated by semicolons.
        adapter_model_ids: String,

        /// The architecture of the model.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        from_uqff: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },

    /// Select a GGUF model.
    #[allow(clippy::upper_case_acronyms)]
    GGUF {
        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
        /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
        /// removing all remote accesses.
        tok_model_id: String,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        quantized_filename: String,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGUF model with X-LoRA.
    XLoraGGUF {
        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
        /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
        /// removing all remote accesses.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        quantized_filename: String,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        xlora_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, then one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        tgt_non_granular_index: Option<usize>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGUF model with LoRA.
    LoraGGUF {
        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
        /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
        /// removing all remote accesses.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        quantized_filename: String,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        adapters_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGML model.
    #[allow(clippy::upper_case_acronyms)]
    GGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        tok_model_id: String,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename.
        quantized_filename: String,

        /// GQA value
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGML model with X-LoRA.
    XLoraGGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename.
        quantized_filename: String,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        xlora_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, then one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        tgt_non_granular_index: Option<usize>,

        /// GQA value
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGML model with LoRA.
    LoraGGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename.
        quantized_filename: String,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        adapters_model_id: String,

        /// Ordering JSON file
        order: String,

        /// GQA value
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a vision plain model, without quantization or adapters
    VisionPlain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        model_id: String,

        /// The architecture of the model.
        arch: VisionLoaderType,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        from_uqff: Option<String>,

        /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
        /// This is only supported on the Qwen2-VL and Idefics 2 models. Others handle this internally.
        max_edge: Option<u32>,

        /// Generate and utilize an imatrix to enhance GGUF quantizations.
        calibration_file: Option<PathBuf>,

        /// .cimatrix file to enhance GGUF quantizations with. Note that this must be a .cimatrix file, not a .imatrix file.
        imatrix: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Maximum prompt number of images to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_num_images")]
        max_num_images: usize,

        /// Maximum expected image size will have this edge length on both edges.
        /// This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_image_length")]
        max_image_length: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },
}
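
// Because `TomlModelSelected` is `#[serde(untagged)]`, the variant is selected
// by which keys are present in the TOML rather than by an explicit tag. A
// minimal sketch of a `[model]` table that would deserialize into the `Plain`
// variant (the model ID and `arch` values are illustrative placeholders):
//
//     [model]
//     model_id = "mistralai/Mistral-7B-Instruct-v0.1"
//     arch = "mistral"
//
// Supplying the GGUF-specific keys (`tok_model_id`, `quantized_model_id`,
// `quantized_filename`) instead would select the `GGUF` variant.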

#[derive(Deserialize)]
pub struct SpeculativeTomlModelSelected {
    /// Gamma value for speculative decoding
    gamma: usize,

    /// Draft model
    draft_model: TomlModelSelected,
}
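
// A hedged sketch of a `[speculative]` section: `gamma` is the speculative
// decoding gamma (commonly the number of draft tokens proposed per step), and
// `draft_model` is itself a full `TomlModelSelected` (the draft model ID here
// is a placeholder):
//
//     [speculative]
//     gamma = 32
//
//     [speculative.draft_model]
//     model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"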

#[derive(Deserialize)]
pub struct AnyMoeTomlModelSelected {
    /// AnyMoE config
    config: AnyMoeConfig,

    /// Path to the dataset JSON file
    dataset_json: String,

    /// Prefix of the mlp key (the part before the layer number: "a.b.c" in "a.b.c.0.mlp")
    prefix: String,

    /// Name of the mlp key (the part after the layer number: "mlp" in "a.b.c.0.mlp")
    mlp: String,

    /// Expert model IDs
    model_ids: Vec<String>,

    /// Layer IDs (zero-indexed) of the layers to apply AnyMoE to; if empty, all layers are used
    #[serde(default = "default_empty_vec_usize")]
    layers: Vec<usize>,
}
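
// A hedged sketch of an `[anymoe]` section matching the fields above. The
// exact keys under `[anymoe.config]` depend on `AnyMoeConfig`'s serde layout,
// and all values here are illustrative placeholders:
//
//     [anymoe]
//     dataset_json = "examples/amoe.json"
//     prefix = "model.layers"
//     mlp = "mlp"
//     model_ids = ["HuggingFaceH4/zephyr-7b-beta"]
//     layers = [0, 1, 2]
//
//     [anymoe.config]
//     # fields of `AnyMoeConfig`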

#[derive(Deserialize)]
pub struct TomlSelector {
    /// Path to a local tokenizer.json file. If specified, it is used over any remote file.
    tokenizer_json: Option<String>,

    /// Selected model
    model: TomlModelSelected,

    /// Speculative model selector
    speculative: Option<SpeculativeTomlModelSelected>,

    /// AnyMoE config
    anymoe: Option<AnyMoeTomlModelSelected>,
}
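
// Putting it together: a selector file is a top-level table with an optional
// `tokenizer_json`, a required `[model]` section, and optional `[speculative]`
// and `[anymoe]` sections. A minimal sketch (paths and IDs are placeholders):
//
//     tokenizer_json = "path/to/tokenizer.json"
//
//     [model]
//     model_id = "mistralai/Mistral-7B-Instruct-v0.1"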

#[derive(Clone)]
struct TomlLoaderInnerParams {
    use_flash_attn: bool,
    chat_template: Option<String>,
    no_kv_cache: bool,
    tokenizer_json: Option<String>,
    prompt_chunksize: Option<NonZeroUsize>,
    jinja_explicit: Option<String>,
}

pub struct TomlLoaderArgs {
    pub use_flash_attn: bool,
    pub chat_template: Option<String>,
    pub no_kv_cache: bool,
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub jinja_explicit: Option<String>,
}

pub fn get_toml_selected_model_dtype(model: &TomlSelector) -> ModelDType {
    match model.model {
        TomlModelSelected::Plain { dtype, .. }
        | TomlModelSelected::Lora { dtype, .. }
        | TomlModelSelected::XLora { dtype, .. }
        | TomlModelSelected::VisionPlain { dtype, .. }
        | TomlModelSelected::GGUF { dtype, .. }
        | TomlModelSelected::GGML { dtype, .. }
        | TomlModelSelected::XLoraGGUF { dtype, .. }
        | TomlModelSelected::XLoraGGML { dtype, .. }
        | TomlModelSelected::LoraGGUF { dtype, .. }
        | TomlModelSelected::LoraGGML { dtype, .. } => dtype,
    }
}

pub fn get_toml_selected_model_device_map_params(
    model: &TomlSelector,
) -> anyhow::Result<AutoDeviceMapParams> {
    match model.model {
        TomlModelSelected::Plain {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::Lora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::XLora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::GGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::GGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::XLoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::XLoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::LoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::LoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        } => Ok(AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        }),
        TomlModelSelected::VisionPlain {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => Ok(AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (max_image_length, max_image_length),
            max_num_images,
        }),
    }
}
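
// A hedged usage sketch: every text-only variant maps to
// `AutoDeviceMapParams::Text`, while `VisionPlain` maps to
// `AutoDeviceMapParams::Vision` with a square `max_image_shape` of
// (max_image_length, max_image_length). Assuming `selector` was deserialized
// elsewhere:
//
//     let params = get_toml_selected_model_device_map_params(&selector)?;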

fn loader_from_selected(
    args: TomlLoaderInnerParams,
    model: TomlModelSelected,
) -> anyhow::Result<Box<dyn Loader>> {
    let use_flash_attn = args.use_flash_attn;
    let loader: Box<dyn Loader> = match model {
        TomlModelSelected::Plain {
            model_id,
            arch,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                use_flash_attn,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: organization.unwrap_or_default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix,
                calibration_file,
                hf_cache_path,
            },
            args.chat_template,
            args.tokenizer_json,
            Some(model_id),
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(arch)?,
        TomlModelSelected::XLora {
            model_id,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                use_flash_attn,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
            },
            args.chat_template,
            args.tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(arch)?,
        TomlModelSelected::Lora {
            model_id,
            adapter_model_ids,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                use_flash_attn,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
            },
            args.chat_template,
            args.tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapter_model_ids
                .split(MULTI_LORA_DELIMITER)
                .map(ToString::to_string)
                .collect(),
        )
        .build(arch)?,
        TomlModelSelected::GGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        TomlModelSelected::XLoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        TomlModelSelected::LoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        TomlModelSelected::GGML {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            gqa,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            args.tokenizer_json,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        TomlModelSelected::XLoraGGML {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            gqa,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            args.tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        TomlModelSelected::LoraGGML {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            gqa,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            args.tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        TomlModelSelected::VisionPlain {
            model_id,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_edge,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            imatrix,
            hf_cache_path,
        } => VisionLoaderBuilder::new(
            VisionSpecificConfig {
                use_flash_attn,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                max_edge,
                calibration_file,
                imatrix,
                hf_cache_path,
            },
            args.chat_template,
            args.tokenizer_json,
            Some(model_id),
            args.jinja_explicit,
        )
        .build(arch),
    };
    Ok(loader)
}

impl TryInto<Box<dyn Loader>> for (TomlSelector, TomlLoaderArgs) {
    type Error = anyhow::Error;
    fn try_into(self) -> Result<Box<dyn Loader>, Self::Error> {
        let (selector, args) = self;
        let args = TomlLoaderInnerParams {
            use_flash_attn: args.use_flash_attn,
            chat_template: args.chat_template,
            no_kv_cache: args.no_kv_cache,
            tokenizer_json: selector.tokenizer_json,
            prompt_chunksize: args.prompt_chunksize,
            jinja_explicit: args.jinja_explicit,
        };
        let loader = loader_from_selected(args.clone(), selector.model)?;
        let loader = if let Some(speculative) = selector.speculative {
            let draft_loader = loader_from_selected(args, speculative.draft_model)?;
            Box::new(SpeculativeLoader {
                target: loader,
                draft: draft_loader,
                config: SpeculativeConfig {
                    gamma: speculative.gamma,
                },
            })
        } else {
            loader
        };
        let loader = if let Some(AnyMoeTomlModelSelected {
            config,
            dataset_json,
            prefix,
            mlp,
            model_ids,
            layers,
        }) = selector.anymoe
        {
            Box::new(AnyMoeLoader {
                target: loader,
                config,
                path: dataset_json,
                prefix,
                mlp,
                model_ids,
                layers,
            })
        } else {
            loader
        };
        Ok(loader)
    }
}
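
// A hedged end-to-end sketch of how this impl is meant to be driven (assumes
// the caller has the `toml` crate available; the file path is a placeholder):
//
//     let selector: TomlSelector =
//         toml::from_str(&std::fs::read_to_string("model.toml")?)?;
//     let args = TomlLoaderArgs {
//         use_flash_attn: false,
//         chat_template: None,
//         no_kv_cache: false,
//         prompt_chunksize: None,
//         jinja_explicit: None,
//     };
//     let loader: Box<dyn Loader> = (selector, args).try_into()?;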