mistralrs_core/
model_selected.rs

1use std::path::PathBuf;
2
3use clap::Subcommand;
4
5use crate::{
6    pipeline::{AutoDeviceMapParams, IsqOrganization, NormalLoaderType, VisionLoaderType},
7    DiffusionLoaderType, ModelDType, SpeechLoaderType,
8};
9
10fn parse_arch(x: &str) -> Result<NormalLoaderType, String> {
11    x.parse()
12}
13
14fn parse_vision_arch(x: &str) -> Result<VisionLoaderType, String> {
15    x.parse()
16}
17
18fn parse_diffusion_arch(x: &str) -> Result<DiffusionLoaderType, String> {
19    x.parse()
20}
21
22fn parse_speech_arch(x: &str) -> Result<SpeechLoaderType, String> {
23    x.parse()
24}
25
26fn parse_model_dtype(x: &str) -> Result<ModelDType, String> {
27    x.parse()
28}
29
/// CLI selector for which model to load, with one subcommand per supported
/// model source/format: plain safetensors models, (X-)LoRA adapters, GGUF and
/// GGML quantized models (optionally with adapters), and vision, diffusion,
/// and speech models. Doc comments on variants and fields become the clap
/// `--help` text.
#[derive(Debug, Subcommand)]
pub enum ModelSelected {
    /// Select the model from a toml file
    Toml {
        /// .toml file containing the selector configuration.
        #[arg(short, long)]
        file: String,
    },

    /// Select a plain model, without quantization or adapters
    Plain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: String,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(short, long)]
        tokenizer_json: Option<String>,

        /// The architecture of the model.
        #[arg(short, long, value_parser = parse_arch)]
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        #[allow(rustdoc::bare_urls)]
        /// ISQ organization: `default` or `moqe` (Mixture of Quantized Experts: https://arxiv.org/abs/2310.02410).
        #[arg(short, long)]
        organization: Option<IsqOrganization>,

        /// UQFF path to write to.
        #[arg(short, long)]
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;)
        #[arg(short, long)]
        from_uqff: Option<String>,

        /// .imatrix file to enhance GGUF quantizations with.
        /// Incompatible with `--calibration-file/-c`
        #[arg(short, long)]
        imatrix: Option<PathBuf>,

        /// Generate and utilize an imatrix to enhance GGUF quantizations.
        /// Incompatible with `--imatrix/-i`
        #[arg(short, long)]
        calibration_file: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        #[arg(long)]
        hf_cache_path: Option<PathBuf>,
    },

    /// Select an X-LoRA architecture
    XLora {
        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: Option<String>,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(short, long)]
        tokenizer_json: Option<String>,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        xlora_model_id: String,

        /// Ordering JSON file
        #[arg(short, long)]
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
        /// This makes the maximum running sequences 1.
        #[arg(long)]
        tgt_non_granular_index: Option<usize>,

        /// The architecture of the model.
        #[arg(short, long, value_parser = parse_arch)]
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// UQFF path to write to.
        #[arg(short, long)]
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
        #[arg(short, long)]
        from_uqff: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        #[arg(long)]
        hf_cache_path: Option<PathBuf>,
    },

    /// Select a LoRA architecture
    Lora {
        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: Option<String>,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(short, long)]
        tokenizer_json: Option<String>,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        adapter_model_id: String,

        /// The architecture of the model.
        // NOTE: long-only (no `short`) so it cannot clash with `-a` from `adapter_model_id`.
        #[arg(long, value_parser = parse_arch)]
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// UQFF path to write to.
        #[arg(short, long)]
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
        #[arg(short, long)]
        from_uqff: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        #[arg(long)]
        hf_cache_path: Option<PathBuf>,
    },

    /// Select a GGUF model.
    GGUF {
        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
        /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
        /// removing all remote accesses.
        #[arg(short, long)]
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },

    /// Select a GGUF model with X-LoRA.
    XLoraGGUF {
        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
        /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
        /// removing all remote accesses.
        #[arg(short, long)]
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        xlora_model_id: String,

        /// Ordering JSON file
        #[arg(short, long)]
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
        /// This makes the maximum running sequences 1.
        #[arg(long)]
        tgt_non_granular_index: Option<usize>,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },

    /// Select a GGUF model with LoRA.
    LoraGGUF {
        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
        /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
        /// removing all remote accesses.
        #[arg(short, long)]
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        adapters_model_id: String,

        /// Ordering JSON file
        #[arg(short, long)]
        order: String,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },

    /// Select a GGML model.
    GGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        tok_model_id: String,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        // NOTE: long-only (no `short`) so it cannot clash with `-t` from `tok_model_id`.
        #[arg(long)]
        tokenizer_json: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// GQA value
        #[arg(short, long, default_value_t = 1)]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },

    /// Select a GGML model with X-LoRA.
    XLoraGGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        tok_model_id: Option<String>,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(long)]
        tokenizer_json: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        xlora_model_id: String,

        /// Ordering JSON file
        #[arg(short, long)]
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
        /// This makes the maximum running sequences 1.
        #[arg(long)]
        tgt_non_granular_index: Option<usize>,

        /// GQA value
        #[arg(short, long, default_value_t = 1)]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },

    /// Select a GGML model with LoRA.
    LoraGGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        tok_model_id: Option<String>,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(long)]
        tokenizer_json: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        adapters_model_id: String,

        /// Ordering JSON file
        #[arg(short, long)]
        order: String,

        /// GQA value
        #[arg(short, long, default_value_t = 1)]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },

    /// Select a vision plain model, without quantization or adapters
    VisionPlain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: String,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(short, long)]
        tokenizer_json: Option<String>,

        /// The architecture of the model.
        #[arg(short, long, value_parser = parse_vision_arch)]
        arch: Option<VisionLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// UQFF path to write to.
        #[arg(short, long)]
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
        #[arg(short, long)]
        from_uqff: Option<String>,

        /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
        /// This is only supported on the Qwen2-VL and Idefics models. Others handle this internally.
        #[arg(short = 'e', long)]
        max_edge: Option<u32>,

        /// Generate and utilize an imatrix to enhance GGUF quantizations.
        #[arg(short, long)]
        calibration_file: Option<PathBuf>,

        /// .cimatrix file to enhance GGUF quantizations with. This must be a .cimatrix file.
        #[arg(short, long)]
        imatrix: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,

        /// Maximum prompt number of images to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES)]
        max_num_images: usize,

        /// Maximum expected image size will have this edge length on both edges.
        /// This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH)]
        max_image_length: usize,

        /// Cache path for Hugging Face models downloaded locally
        #[arg(long)]
        hf_cache_path: Option<PathBuf>,
    },

    /// Select a diffusion plain model, without quantization or adapters
    DiffusionPlain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: String,

        /// The architecture of the model.
        #[arg(short, long, value_parser = parse_diffusion_arch)]
        arch: DiffusionLoaderType,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,
    },

    /// Select a speech model
    Speech {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: String,

        /// DAC Model ID to load from. If not provided, this is automatically downloaded from the default path for the model.
        /// This may be a HF hub repo or a local path.
        #[arg(short, long)]
        dac_model_id: Option<String>,

        /// The architecture of the model.
        #[arg(short, long, value_parser = parse_speech_arch)]
        arch: SpeechLoaderType,

        /// Model data type. Defaults to `auto`.
        // NOTE: long-only (no `short`) so it cannot clash with `-d` from `dac_model_id`.
        #[arg(long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,
    },
}