// mistralrs_core/model_selected.rs

use std::path::PathBuf;

use clap::Subcommand;

use crate::{
    pipeline::{AutoDeviceMapParams, IsqOrganization, NormalLoaderType, VisionLoaderType},
    DiffusionLoaderType, ModelDType,
};

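// Thin `value_parser` shims for clap: each defers to the target type's
// `FromStr` impl and surfaces the parse failure as the `String` that clap
// reports to the user.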
fn parse_arch(x: &str) -> Result<NormalLoaderType, String> {
    x.parse()
}

fn parse_vision_arch(x: &str) -> Result<VisionLoaderType, String> {
    x.parse()
}

fn parse_diffusion_arch(x: &str) -> Result<DiffusionLoaderType, String> {
    x.parse()
}

fn parse_model_dtype(x: &str) -> Result<ModelDType, String> {
    x.parse()
}

#[derive(Debug, Subcommand)]
pub enum ModelSelected {
    /// Select the model from a TOML file
    Toml {
        /// .toml file containing the selector configuration.
        #[arg(short, long)]
        file: String,
    },
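    // Illustrative usage (file name is a hypothetical example):
    //   `toml -f selector.toml`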

    /// Select a plain model, without quantization or adapters
    Plain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: String,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(short, long)]
        tokenizer_json: Option<String>,

        /// The architecture of the model.
        #[arg(short, long, value_parser = parse_arch)]
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        #[allow(rustdoc::bare_urls)]
        /// ISQ organization: `default` or `moqe` (Mixture of Quantized Experts: https://arxiv.org/abs/2310.02410).
        #[arg(short, long)]
        organization: Option<IsqOrganization>,

        /// UQFF path to write to.
        #[arg(short, long)]
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        #[arg(short, long)]
        from_uqff: Option<PathBuf>,

        /// .imatrix file to enhance GGUF quantizations with.
        /// Incompatible with `--calibration-file/-c`
        #[arg(short, long)]
        imatrix: Option<PathBuf>,

        /// Generate and utilize an imatrix to enhance GGUF quantizations.
        /// Incompatible with `--imatrix/-i`
        #[arg(short, long)]
        calibration_file: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        #[arg(long)]
        hf_cache_path: Option<PathBuf>,
    },
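    // Illustrative usage (model ID and arch are hypothetical examples):
    //   `plain -m mistralai/Mistral-7B-Instruct-v0.2 -a mistral`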

    /// Select an X-LoRA architecture
    XLora {
        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: Option<String>,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(short, long)]
        tokenizer_json: Option<String>,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        xlora_model_id: String,

        /// Ordering JSON file
        #[arg(short, long)]
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        #[arg(long)]
        tgt_non_granular_index: Option<usize>,

        /// The architecture of the model.
        #[arg(short, long, value_parser = parse_arch)]
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// UQFF path to write to.
        #[arg(short, long)]
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        #[arg(short, long)]
        from_uqff: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        #[arg(long)]
        hf_cache_path: Option<PathBuf>,
    },
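    // Illustrative usage (IDs and ordering file are hypothetical examples):
    //   `x-lora -m base/model -x user/xlora-adapter -o ordering.json`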

    /// Select a LoRA architecture
    Lora {
        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: Option<String>,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(short, long)]
        tokenizer_json: Option<String>,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        adapter_model_id: String,

        /// The architecture of the model.
        #[arg(long, value_parser = parse_arch)]
        arch: Option<NormalLoaderType>,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// UQFF path to write to.
        #[arg(short, long)]
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        #[arg(short, long)]
        from_uqff: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,

        /// Cache path for Hugging Face models downloaded locally
        #[arg(long)]
        hf_cache_path: Option<PathBuf>,
    },
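    // Illustrative usage (IDs are hypothetical examples):
    //   `lora -m base/model -a user/lora-adapter`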

    /// Select a GGUF model.
    GGUF {
        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
        /// If `chat_template` is specified, then it will be treated as a path and used over remote files,
        /// removing all remote accesses.
        #[arg(short, long)]
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },
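    // Illustrative usage (repo and filename are hypothetical examples):
    //   `gguf -m TheBloke/Mistral-7B-Instruct-v0.2-GGUF -f mistral-7b-instruct-v0.2.Q4_K_M.gguf`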

    /// Select a GGUF model with X-LoRA.
    XLoraGGUF {
        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
        /// If `chat_template` is specified, then it will be treated as a path and used over remote files,
        /// removing all remote accesses.
        #[arg(short, long)]
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        xlora_model_id: String,

        /// Ordering JSON file
        #[arg(short, long)]
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        #[arg(long)]
        tgt_non_granular_index: Option<usize>,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },

    /// Select a GGUF model with LoRA.
    LoraGGUF {
        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
        /// If `chat_template` is specified, then it will be treated as a path and used over remote files,
        /// removing all remote accesses.
        #[arg(short, long)]
        tok_model_id: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename(s).
        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        adapters_model_id: String,

        /// Ordering JSON file
        #[arg(short, long)]
        order: String,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },

    /// Select a GGML model.
    GGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        tok_model_id: String,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(long)]
        tokenizer_json: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// GQA value (grouped-query attention factor)
        #[arg(short, long, default_value_t = 1)]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },
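    // Illustrative usage (IDs and filename are hypothetical examples):
    //   `ggml -t tokenizer/repo -m quantized/repo -f model-q4_0.bin --gqa 8`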

    /// Select a GGML model with X-LoRA.
    XLoraGGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        tok_model_id: Option<String>,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(long)]
        tokenizer_json: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        xlora_model_id: String,

        /// Ordering JSON file
        #[arg(short, long)]
        order: String,

        /// Index of completion tokens to generate scalings up until. If this is 1, one completion token is generated before the scalings are cached.
        /// This limits the maximum number of running sequences to 1.
        #[arg(long)]
        tgt_non_granular_index: Option<usize>,

        /// GQA value (grouped-query attention factor)
        #[arg(short, long, default_value_t = 1)]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },

    /// Select a GGML model with LoRA.
    LoraGGML {
        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        tok_model_id: Option<String>,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(long)]
        tokenizer_json: Option<String>,

        /// Quantized model ID to find the `quantized_filename`.
        /// This may be a HF hub repo or a local path.
        #[arg(short = 'm', long)]
        quantized_model_id: String,

        /// Quantized filename.
        #[arg(short = 'f', long)]
        quantized_filename: String,

        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        adapters_model_id: String,

        /// Ordering JSON file
        #[arg(short, long)]
        order: String,

        /// GQA value (grouped-query attention factor)
        #[arg(short, long, default_value_t = 1)]
        gqa: usize,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,
    },

    /// Select a vision plain model, without quantization or adapters
    VisionPlain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: String,

        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
        #[arg(short, long)]
        tokenizer_json: Option<String>,

        /// The architecture of the model.
        #[arg(short, long, value_parser = parse_vision_arch)]
        arch: VisionLoaderType,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,

        /// UQFF path to write to.
        #[arg(short, long)]
        write_uqff: Option<PathBuf>,

        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
        #[arg(short, long)]
        from_uqff: Option<PathBuf>,

        /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
        /// This is only supported on the Qwen2-VL and Idefics models. Others handle this internally.
        #[arg(short = 'e', long)]
        max_edge: Option<u32>,

        /// Generate and utilize an imatrix to enhance GGUF quantizations.
        #[arg(short, long)]
        calibration_file: Option<PathBuf>,

        /// .cimatrix file to enhance GGUF quantizations with. This must be a .cimatrix file, not a plain .imatrix file.
        #[arg(short, long)]
        imatrix: Option<PathBuf>,

        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
        max_seq_len: usize,

        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
        max_batch_size: usize,

        /// Maximum number of prompt images to expect for this model. This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES)]
        max_num_images: usize,

        /// Maximum expected image size will have this edge length on both edges.
        /// This affects automatic device mapping but is not a hard limit.
        #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH)]
        max_image_length: usize,

        /// Cache path for Hugging Face models downloaded locally
        #[arg(long)]
        hf_cache_path: Option<PathBuf>,
    },
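    // Illustrative usage (model ID and arch are hypothetical examples):
    //   `vision-plain -m some/vision-model -a qwen2vl`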

    /// Select a diffusion plain model, without quantization or adapters
    DiffusionPlain {
        /// Model ID to load from. This may be a HF hub repo or a local path.
        #[arg(short, long)]
        model_id: String,

        /// The architecture of the model.
        #[arg(short, long, value_parser = parse_diffusion_arch)]
        arch: DiffusionLoaderType,

        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,
    },
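    // Illustrative usage (model ID and arch are hypothetical examples):
    //   `diffusion-plain -m black-forest-labs/FLUX.1-schnell -a flux`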
}
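
// A minimal sketch of how this enum is consumed, assuming a hypothetical
// top-level `clap::Parser` wrapper (the real CLI lives elsewhere in the
// workspace). `#[derive(Subcommand)]` turns each variant into a kebab-case
// subcommand (`plain`, `x-lora-gguf`, `vision-plain`, ...).
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;

    // Hypothetical wrapper used only for this sketch.
    #[derive(Parser)]
    struct Cli {
        #[command(subcommand)]
        model: ModelSelected,
    }

    #[test]
    fn parses_plain_subcommand() {
        // Omitted flags fall back to their clap defaults (`dtype = auto`,
        // `max_seq_len`/`max_batch_size` from `AutoDeviceMapParams`,
        // everything else `None`).
        let cli = Cli::parse_from(["prog", "plain", "--model-id", "some/model-id"]);
        match cli.model {
            ModelSelected::Plain { model_id, arch, .. } => {
                assert_eq!(model_id, "some/model-id");
                assert!(arch.is_none());
            }
            _ => panic!("expected the `plain` subcommand"),
        }
    }
}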