mistralrs_core/
toml_selector.rs

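//! Deserialization of TOML model-selector files and conversion of the
//! selected configuration into a [`Loader`].
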
use std::{fs::File, num::NonZeroUsize, path::PathBuf, str::FromStr};

use mistralrs_quant::MULTI_LORA_DELIMITER;
use serde::Deserialize;

use crate::{
    amoe::AnyMoeConfig, pipeline::IsqOrganization, AnyMoeLoader, AutoDeviceMapParams,
    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader,
    ModelDType, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig,
    SpeculativeLoader, Topology, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
    GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
};

fn default_one() -> usize {
    1
}

fn default_dtype() -> ModelDType {
    ModelDType::Auto
}

fn default_empty_vec_usize() -> Vec<usize> {
    Vec::new()
}

fn default_max_seq_len() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN
}

fn default_max_batch_size() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE
}

fn default_max_num_images() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES
}

fn default_max_image_length() -> usize {
    AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH
}

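/// Model selection, deserialized from the `[model]` table of a TOML selector
/// file (see [`TomlSelector`]). The enum is `#[serde(untagged)]`, so the
/// variant is inferred from which keys are present, trying the variants in
/// declaration order.
///
/// A minimal sketch selecting a plain model (the model ID is a placeholder):
///
/// ```toml
/// [model]
/// model_id = "meta-llama/Llama-3.1-8B-Instruct"
/// ```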
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum TomlModelSelected {
    /// Select a plain model, without quantization or adapters
    Plain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        model_id: String,

        /// The architecture of the model.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// ISQ organization: `default` or `moqe` (Mixture of Quantized Experts: https://arxiv.org/abs/2310.02410).
        organization: Option<IsqOrganization>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        from_uqff: Option<String>,

        /// .imatrix file to enhance GGUF quantizations with.
        /// Incompatible with `--calibration-file/-c`
        imatrix: Option<PathBuf>,

        /// Generate and utilize an imatrix to enhance GGUF quantizations.
        /// Incompatible with `--imatrix/-i`
        calibration_file: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },

    /// Select an X-LoRA architecture
    XLora {
        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
        model_id: Option<String>,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        xlora_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        tgt_non_granular_index: Option<usize>,

        /// The architecture of the model.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        from_uqff: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },

    /// Select a LoRA architecture
    Lora {
        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
        model_id: Option<String>,

        /// Model IDs to load LoRA from. Each may be a HF hub repo or a local path. Specify multiple with a semicolon delimiter.
        adapter_model_ids: String,

        /// The architecture of the model.
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        from_uqff: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },

    /// Select a GGUF model.
    #[allow(clippy::upper_case_acronyms)]
    GGUF {
        /// `tok_model_id` is the local or remote model ID where a `tokenizer_config.json` file can be found.
        /// If `chat_template` is specified, it is treated as a path and used instead of remote files,
        /// avoiding all remote accesses.
        tok_model_id: String,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        quantized_filename: String,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGUF model with X-LoRA.
    XLoraGGUF {
        /// `tok_model_id` is the local or remote model ID where a `tokenizer_config.json` file can be found.
        /// If `chat_template` is specified, it is treated as a path and used instead of remote files,
        /// avoiding all remote accesses.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        quantized_filename: String,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        xlora_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        tgt_non_granular_index: Option<usize>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGUF model with LoRA.
    LoraGGUF {
        /// `tok_model_id` is the local or remote model ID where a `tokenizer_config.json` file can be found.
        /// If `chat_template` is specified, it is treated as a path and used instead of remote files,
        /// avoiding all remote accesses.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        quantized_filename: String,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        adapters_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGML model.
    #[allow(clippy::upper_case_acronyms)]
    GGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        tok_model_id: String,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename.
        quantized_filename: String,

        /// GQA value (grouped-query attention)
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGML model with X-LoRA.
    XLoraGGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename.
        quantized_filename: String,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        xlora_model_id: String,

        /// Ordering JSON file
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        tgt_non_granular_index: Option<usize>,

        /// GQA value (grouped-query attention)
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a GGML model with LoRA.
    LoraGGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        quantized_model_id: String,

        /// Quantized filename.
        quantized_filename: String,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        adapters_model_id: String,

        /// Ordering JSON file
        order: String,

        /// GQA value (grouped-query attention)
        #[serde(default = "default_one")]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Select a vision plain model, without quantization or adapters
    VisionPlain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        model_id: String,

        /// The architecture of the model.
        arch: Option<VisionLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        topology: Option<String>,

        /// UQFF path to write to.
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        from_uqff: Option<String>,

        /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
        /// This is only supported on the Qwen2-VL and Idefics 2 models. Others handle this internally.
        max_edge: Option<u32>,

        /// Generate and utilize an imatrix to enhance GGUF quantizations.
        calibration_file: Option<PathBuf>,

        /// .cimatrix file to enhance GGUF quantizations with. This must specifically be a .cimatrix file.
        imatrix: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Maximum number of images to expect per prompt for this model. This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_num_images")]
        max_num_images: usize,

        /// The maximum expected image size will have this edge length on both edges.
        /// This affects automatic device mapping but is not a hard limit.
        #[serde(default = "default_max_image_length")]
        max_image_length: usize,

        /// Cache path for Hugging Face models downloaded locally
        hf_cache_path: Option<PathBuf>,
    },
}

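/// Speculative decoding selection: the draft model proposes `gamma` tokens
/// per step, which the selected target model then verifies.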
#[derive(Deserialize)]
pub struct SpeculativeTomlModelSelected {
    /// Gamma value for speculative decoding (number of draft tokens per step)
    gamma: usize,

    /// Draft model
    draft_model: TomlModelSelected,
}

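/// AnyMoE selection. `prefix` and `mlp` together locate the MLP blocks to
/// convert into experts: for a weight key such as `a.b.c.0.mlp`, `prefix`
/// is `"a.b.c"` and `mlp` is `"mlp"`.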
#[derive(Deserialize)]
pub struct AnyMoeTomlModelSelected {
    /// AnyMoE config
    config: AnyMoeConfig,

    /// Path to the dataset JSON file
    dataset_json: String,

    /// Prefix of the mlp key (the part before the layer number: "a.b.c" in "a.b.c.0.mlp")
    prefix: String,

    /// Name of the mlp key (the part after the layer number: "mlp" in "a.b.c.0.mlp")
    mlp: String,

    /// Expert model ids
    model_ids: Vec<String>,

    /// Layer ids (zero indexed) of layers to apply AnyMoE to; if empty, all layers are used
    #[serde(default = "default_empty_vec_usize")]
    layers: Vec<usize>,
}

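/// Top-level schema of a TOML selector file.
///
/// A hypothetical complete example (model IDs and paths are placeholders;
/// the `[speculative]` and `[anymoe]` tables are optional):
///
/// ```toml
/// tokenizer_json = "path/to/tokenizer.json"
///
/// [model]
/// model_id = "meta-llama/Llama-3.1-8B-Instruct"
///
/// [speculative]
/// gamma = 4
///
/// [speculative.draft_model]
/// model_id = "meta-llama/Llama-3.2-1B-Instruct"
/// ```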
#[derive(Deserialize)]
pub struct TomlSelector {
    /// Path to a local tokenizer.json file. If specified, it is used instead of any remote file.
    tokenizer_json: Option<String>,

    /// Selected model
    model: TomlModelSelected,

    /// Speculative model selector
    speculative: Option<SpeculativeTomlModelSelected>,

    /// AnyMoE config
    anymoe: Option<AnyMoeTomlModelSelected>,
}

#[derive(Clone)]
struct TomlLoaderInnerParams {
    chat_template: Option<String>,
    no_kv_cache: bool,
    tokenizer_json: Option<String>,
    prompt_chunksize: Option<NonZeroUsize>,
    jinja_explicit: Option<String>,
}

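/// Arguments supplied by the caller alongside a parsed [`TomlSelector`];
/// merged into the internal loader parameters when building the [`Loader`].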
pub struct TomlLoaderArgs {
    pub chat_template: Option<String>,
    pub no_kv_cache: bool,
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub jinja_explicit: Option<String>,
}

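/// Extract the selected [`ModelDType`]; every variant carries one.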
pub fn get_toml_selected_model_dtype(model: &TomlSelector) -> ModelDType {
    match model.model {
        TomlModelSelected::Plain { dtype, .. }
        | TomlModelSelected::Lora { dtype, .. }
        | TomlModelSelected::XLora { dtype, .. }
        | TomlModelSelected::VisionPlain { dtype, .. }
        | TomlModelSelected::GGUF { dtype, .. }
        | TomlModelSelected::GGML { dtype, .. }
        | TomlModelSelected::XLoraGGUF { dtype, .. }
        | TomlModelSelected::XLoraGGML { dtype, .. }
        | TomlModelSelected::LoraGGUF { dtype, .. }
        | TomlModelSelected::LoraGGML { dtype, .. } => dtype,
    }
}

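/// Build the [`AutoDeviceMapParams`] for the selected model. Text-only
/// selections map to [`AutoDeviceMapParams::Text`]; vision selections also
/// carry the expected maximum image count and image shape.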
pub fn get_toml_selected_model_device_map_params(
    model: &TomlSelector,
) -> anyhow::Result<AutoDeviceMapParams> {
    match model.model {
        TomlModelSelected::Plain {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::Lora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::XLora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::GGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::GGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::XLoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::XLoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::LoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | TomlModelSelected::LoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        } => Ok(AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        }),
        TomlModelSelected::VisionPlain {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => Ok(AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (max_image_length, max_image_length),
            max_num_images,
        }),
    }
}

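/// Construct the concrete [`Loader`] for a single selection. The `dtype`,
/// `max_seq_len`, and `max_batch_size` fields are deliberately ignored here;
/// they are read separately via the `get_toml_selected_model_*` helpers above.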
fn loader_from_selected(
    args: TomlLoaderInnerParams,
    model: TomlModelSelected,
) -> anyhow::Result<Box<dyn Loader>> {
    let loader: Box<dyn Loader> = match model {
        TomlModelSelected::Plain {
            model_id,
            arch,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: organization.unwrap_or_default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix,
                calibration_file,
                hf_cache_path,
            },
            args.chat_template,
            args.tokenizer_json,
            Some(model_id),
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(arch)?,
        TomlModelSelected::XLora {
            model_id,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
            },
            args.chat_template,
            args.tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(arch)?,
        TomlModelSelected::Lora {
            model_id,
            adapter_model_ids,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
            },
            args.chat_template,
            args.tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapter_model_ids
                .split(MULTI_LORA_DELIMITER)
                .map(ToString::to_string)
                .collect(),
        )
        .build(arch)?,
        TomlModelSelected::GGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        TomlModelSelected::XLoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        TomlModelSelected::LoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        TomlModelSelected::GGML {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            gqa,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            args.tokenizer_json,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        TomlModelSelected::XLoraGGML {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            gqa,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            args.tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        TomlModelSelected::LoraGGML {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            gqa,
            topology,
            dtype: _,
            max_seq_len: _,
            max_batch_size: _,
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            args.tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        TomlModelSelected::VisionPlain {
            model_id,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_edge,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            imatrix,
            hf_cache_path,
        } => VisionLoaderBuilder::new(
            VisionSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                max_edge,
                calibration_file,
                imatrix,
                hf_cache_path,
            },
            args.chat_template,
            args.tokenizer_json,
            Some(model_id),
            args.jinja_explicit,
        )
        .build(arch),
    };
    Ok(loader)
}

impl TryInto<Box<dyn Loader>> for (TomlSelector, TomlLoaderArgs) {
    type Error = anyhow::Error;
    fn try_into(self) -> Result<Box<dyn Loader>, Self::Error> {
        let (selector, args) = self;
        let args = TomlLoaderInnerParams {
            chat_template: args.chat_template,
            no_kv_cache: args.no_kv_cache,
            tokenizer_json: selector.tokenizer_json,
            prompt_chunksize: args.prompt_chunksize,
            jinja_explicit: args.jinja_explicit,
        };
        let loader = loader_from_selected(args.clone(), selector.model)?;
        let loader = if let Some(speculative) = selector.speculative {
            let draft_loader = loader_from_selected(args, speculative.draft_model)?;
            Box::new(SpeculativeLoader {
                target: loader,
                draft: draft_loader,
                config: SpeculativeConfig {
                    gamma: speculative.gamma,
                },
            })
        } else {
            loader
        };
        let loader = if let Some(AnyMoeTomlModelSelected {
            config,
            dataset_json,
            prefix,
            mlp,
            model_ids,
            layers,
        }) = selector.anymoe
        {
            Box::new(AnyMoeLoader {
                target: loader,
                config,
                path: dataset_json,
                prefix,
                mlp,
                model_ids,
                layers,
            })
        } else {
            loader
        };
        Ok(loader)
    }
}
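
// A minimal usage sketch (assuming the `toml` crate is available and
// `selector.toml` follows the schema above; names here are illustrative,
// not part of this module's API):
//
//     let selector: TomlSelector =
//         toml::from_str(&std::fs::read_to_string("selector.toml")?)?;
//     let args = TomlLoaderArgs {
//         chat_template: None,
//         no_kv_cache: false,
//         prompt_chunksize: None,
//         jinja_explicit: None,
//     };
//     let loader: Box<dyn Loader> = (selector, args).try_into()?;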