// mistralrs_core/model_selected.rs
use std::path::PathBuf;

use clap::Subcommand;

use crate::{
    pipeline::{AutoDeviceMapParams, IsqOrganization, NormalLoaderType, VisionLoaderType},
    DiffusionLoaderType, ModelDType, SpeechLoaderType,
};
9
10fn parse_arch(x: &str) -> Result<NormalLoaderType, String> {
11 x.parse()
12}
13
14fn parse_vision_arch(x: &str) -> Result<VisionLoaderType, String> {
15 x.parse()
16}
17
18fn parse_diffusion_arch(x: &str) -> Result<DiffusionLoaderType, String> {
19 x.parse()
20}
21
22fn parse_speech_arch(x: &str) -> Result<SpeechLoaderType, String> {
23 x.parse()
24}
25
26fn parse_model_dtype(x: &str) -> Result<ModelDType, String> {
27 x.parse()
28}
29
30#[derive(Debug, Subcommand)]
31pub enum ModelSelected {
32 /// Select the model from a toml file
33 Toml {
34 /// .toml file containing the selector configuration.
35 #[arg(short, long)]
36 file: String,
37 },
38
39 /// Select a plain model, without quantization or adapters
40 Plain {
41 /// Model ID to load from. This may be a HF hub repo or a local path.
42 #[arg(short, long)]
43 model_id: String,
44
45 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
46 #[arg(short, long)]
47 tokenizer_json: Option<String>,
48
49 /// The architecture of the model.
50 #[arg(short, long, value_parser = parse_arch)]
51 arch: Option<NormalLoaderType>,
52
53 /// Model data type. Defaults to `auto`.
54 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
55 dtype: ModelDType,
56
57 /// Path to a topology YAML file.
58 #[arg(long)]
59 topology: Option<String>,
60
61 #[allow(rustdoc::bare_urls)]
62 /// ISQ organization: `default` or `moqe` (Mixture of Quantized Experts: https://arxiv.org/abs/2310.02410).
63 #[arg(short, long)]
64 organization: Option<IsqOrganization>,
65
66 /// UQFF path to write to.
67 #[arg(short, long)]
68 write_uqff: Option<PathBuf>,
69
70 /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;)
71 #[arg(short, long)]
72 from_uqff: Option<String>,
73
74 /// .imatrix file to enhance GGUF quantizations with.
75 /// Incompatible with `--calibration-file/-c`
76 #[arg(short, long)]
77 imatrix: Option<PathBuf>,
78
79 /// Generate and utilize an imatrix to enhance GGUF quantizations.
80 /// Incompatible with `--imatrix/-i`
81 #[arg(short, long)]
82 calibration_file: Option<PathBuf>,
83
84 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
85 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
86 max_seq_len: usize,
87
88 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
89 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
90 max_batch_size: usize,
91
92 /// Cache path for Hugging Face models downloaded locally
93 #[arg(long)]
94 hf_cache_path: Option<PathBuf>,
95 },
96
97 /// Select an X-LoRA architecture
98 XLora {
99 /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
100 #[arg(short, long)]
101 model_id: Option<String>,
102
103 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
104 #[arg(short, long)]
105 tokenizer_json: Option<String>,
106
107 /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
108 #[arg(short, long)]
109 xlora_model_id: String,
110
111 /// Ordering JSON file
112 #[arg(short, long)]
113 order: String,
114
115 /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
116 /// This makes the maximum running sequences 1.
117 #[arg(long)]
118 tgt_non_granular_index: Option<usize>,
119
120 /// The architecture of the model.
121 #[arg(short, long, value_parser = parse_arch)]
122 arch: Option<NormalLoaderType>,
123
124 /// Model data type. Defaults to `auto`.
125 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
126 dtype: ModelDType,
127
128 /// Path to a topology YAML file.
129 #[arg(long)]
130 topology: Option<String>,
131
132 /// UQFF path to write to.
133 #[arg(short, long)]
134 write_uqff: Option<PathBuf>,
135
136 /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
137 #[arg(short, long)]
138 from_uqff: Option<String>,
139
140 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
141 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
142 max_seq_len: usize,
143
144 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
145 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
146 max_batch_size: usize,
147
148 /// Cache path for Hugging Face models downloaded locally
149 #[arg(long)]
150 hf_cache_path: Option<PathBuf>,
151 },
152
153 /// Select a LoRA architecture
154 Lora {
155 /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
156 #[arg(short, long)]
157 model_id: Option<String>,
158
159 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
160 #[arg(short, long)]
161 tokenizer_json: Option<String>,
162
163 /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
164 #[arg(short, long)]
165 adapter_model_id: String,
166
167 /// The architecture of the model.
168 #[arg(long, value_parser = parse_arch)]
169 arch: Option<NormalLoaderType>,
170
171 /// Model data type. Defaults to `auto`.
172 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
173 dtype: ModelDType,
174
175 /// Path to a topology YAML file.
176 #[arg(long)]
177 topology: Option<String>,
178
179 /// UQFF path to write to.
180 #[arg(short, long)]
181 write_uqff: Option<PathBuf>,
182
183 /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
184 #[arg(short, long)]
185 from_uqff: Option<String>,
186
187 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
188 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
189 max_seq_len: usize,
190
191 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
192 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
193 max_batch_size: usize,
194
195 /// Cache path for Hugging Face models downloaded locally
196 #[arg(long)]
197 hf_cache_path: Option<PathBuf>,
198 },
199
200 /// Select a GGUF model.
201 GGUF {
202 /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
203 /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
204 /// removing all remote accesses.
205 #[arg(short, long)]
206 tok_model_id: Option<String>,
207
208 /// Quantized model ID to find the `quantized_filename`.
209 /// This may be a HF hub repo or a local path.
210 #[arg(short = 'm', long)]
211 quantized_model_id: String,
212
213 /// Quantized filename(s).
214 /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
215 #[arg(short = 'f', long)]
216 quantized_filename: String,
217
218 /// Model data type. Defaults to `auto`.
219 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
220 dtype: ModelDType,
221
222 /// Path to a topology YAML file.
223 #[arg(long)]
224 topology: Option<String>,
225
226 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
227 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
228 max_seq_len: usize,
229
230 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
231 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
232 max_batch_size: usize,
233 },
234
235 /// Select a GGUF model with X-LoRA.
236 XLoraGGUF {
237 /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
238 /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
239 /// removing all remote accesses.
240 #[arg(short, long)]
241 tok_model_id: Option<String>,
242
243 /// Quantized model ID to find the `quantized_filename`.
244 /// This may be a HF hub repo or a local path.
245 #[arg(short = 'm', long)]
246 quantized_model_id: String,
247
248 /// Quantized filename(s).
249 /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
250 #[arg(short = 'f', long)]
251 quantized_filename: String,
252
253 /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
254 #[arg(short, long)]
255 xlora_model_id: String,
256
257 /// Ordering JSON file
258 #[arg(short, long)]
259 order: String,
260
261 /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
262 /// This makes the maximum running sequences 1.
263 #[arg(long)]
264 tgt_non_granular_index: Option<usize>,
265
266 /// Model data type. Defaults to `auto`.
267 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
268 dtype: ModelDType,
269
270 /// Path to a topology YAML file.
271 #[arg(long)]
272 topology: Option<String>,
273
274 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
275 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
276 max_seq_len: usize,
277
278 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
279 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
280 max_batch_size: usize,
281 },
282
283 /// Select a GGUF model with LoRA.
284 LoraGGUF {
285 /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
286 /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
287 /// removing all remote accesses.
288 #[arg(short, long)]
289 tok_model_id: Option<String>,
290
291 /// Quantized model ID to find the `quantized_filename`.
292 /// This may be a HF hub repo or a local path.
293 #[arg(short = 'm', long)]
294 quantized_model_id: String,
295
296 /// Quantized filename(s).
297 /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
298 #[arg(short = 'f', long)]
299 quantized_filename: String,
300
301 /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
302 #[arg(short, long)]
303 adapters_model_id: String,
304
305 /// Ordering JSON file
306 #[arg(short, long)]
307 order: String,
308
309 /// Model data type. Defaults to `auto`.
310 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
311 dtype: ModelDType,
312
313 /// Path to a topology YAML file.
314 #[arg(long)]
315 topology: Option<String>,
316
317 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
318 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
319 max_seq_len: usize,
320
321 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
322 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
323 max_batch_size: usize,
324 },
325
326 /// Select a GGML model.
327 GGML {
328 /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
329 #[arg(short, long)]
330 tok_model_id: String,
331
332 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
333 #[arg(long)]
334 tokenizer_json: Option<String>,
335
336 /// Quantized model ID to find the `quantized_filename`.
337 /// This may be a HF hub repo or a local path.
338 #[arg(short = 'm', long)]
339 quantized_model_id: String,
340
341 /// Quantized filename.
342 #[arg(short = 'f', long)]
343 quantized_filename: String,
344
345 /// GQA value
346 #[arg(short, long, default_value_t = 1)]
347 gqa: usize,
348
349 /// Model data type. Defaults to `auto`.
350 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
351 dtype: ModelDType,
352
353 /// Path to a topology YAML file.
354 #[arg(long)]
355 topology: Option<String>,
356
357 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
358 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
359 max_seq_len: usize,
360
361 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
362 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
363 max_batch_size: usize,
364 },
365
366 /// Select a GGML model with X-LoRA.
367 XLoraGGML {
368 /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
369 #[arg(short, long)]
370 tok_model_id: Option<String>,
371
372 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
373 #[arg(long)]
374 tokenizer_json: Option<String>,
375
376 /// Quantized model ID to find the `quantized_filename`.
377 /// This may be a HF hub repo or a local path.
378 #[arg(short = 'm', long)]
379 quantized_model_id: String,
380
381 /// Quantized filename.
382 #[arg(short = 'f', long)]
383 quantized_filename: String,
384
385 /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
386 #[arg(short, long)]
387 xlora_model_id: String,
388
389 /// Ordering JSON file
390 #[arg(short, long)]
391 order: String,
392
393 /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
394 /// This makes the maximum running sequences 1.
395 #[arg(long)]
396 tgt_non_granular_index: Option<usize>,
397
398 /// GQA value
399 #[arg(short, long, default_value_t = 1)]
400 gqa: usize,
401
402 /// Model data type. Defaults to `auto`.
403 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
404 dtype: ModelDType,
405
406 /// Path to a topology YAML file.
407 #[arg(long)]
408 topology: Option<String>,
409
410 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
411 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
412 max_seq_len: usize,
413
414 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
415 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
416 max_batch_size: usize,
417 },
418
419 /// Select a GGML model with LoRA.
420 LoraGGML {
421 /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
422 #[arg(short, long)]
423 tok_model_id: Option<String>,
424
425 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
426 #[arg(long)]
427 tokenizer_json: Option<String>,
428
429 /// Quantized model ID to find the `quantized_filename`.
430 /// This may be a HF hub repo or a local path.
431 #[arg(short = 'm', long)]
432 quantized_model_id: String,
433
434 /// Quantized filename.
435 #[arg(short = 'f', long)]
436 quantized_filename: String,
437
438 /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
439 #[arg(short, long)]
440 adapters_model_id: String,
441
442 /// Ordering JSON file
443 #[arg(short, long)]
444 order: String,
445
446 /// GQA value
447 #[arg(short, long, default_value_t = 1)]
448 gqa: usize,
449
450 /// Model data type. Defaults to `auto`.
451 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
452 dtype: ModelDType,
453
454 /// Path to a topology YAML file.
455 #[arg(long)]
456 topology: Option<String>,
457
458 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
459 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
460 max_seq_len: usize,
461
462 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
463 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
464 max_batch_size: usize,
465 },
466
467 /// Select a vision plain model, without quantization or adapters
468 VisionPlain {
469 /// Model ID to load from. This may be a HF hub repo or a local path.
470 #[arg(short, long)]
471 model_id: String,
472
473 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
474 #[arg(short, long)]
475 tokenizer_json: Option<String>,
476
477 /// The architecture of the model.
478 #[arg(short, long, value_parser = parse_vision_arch)]
479 arch: Option<VisionLoaderType>,
480
481 /// Model data type. Defaults to `auto`.
482 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
483 dtype: ModelDType,
484
485 /// Path to a topology YAML file.
486 #[arg(long)]
487 topology: Option<String>,
488
489 /// UQFF path to write to.
490 #[arg(short, long)]
491 write_uqff: Option<PathBuf>,
492
493 /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;).
494 #[arg(short, long)]
495 from_uqff: Option<String>,
496
497 /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
498 /// This is only supported on the Qwen2-VL and Idefics models. Others handle this internally.
499 #[arg(short = 'e', long)]
500 max_edge: Option<u32>,
501
502 /// Generate and utilize an imatrix to enhance GGUF quantizations.
503 #[arg(short, long)]
504 calibration_file: Option<PathBuf>,
505
506 /// .cimatrix file to enhance GGUF quantizations with. This must be a .cimatrix file.
507 #[arg(short, long)]
508 imatrix: Option<PathBuf>,
509
510 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
511 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
512 max_seq_len: usize,
513
514 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
515 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
516 max_batch_size: usize,
517
518 /// Maximum prompt number of images to expect for this model. This affects automatic device mapping but is not a hard limit.
519 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES)]
520 max_num_images: usize,
521
522 /// Maximum expected image size will have this edge length on both edges.
523 /// This affects automatic device mapping but is not a hard limit.
524 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH)]
525 max_image_length: usize,
526
527 /// Cache path for Hugging Face models downloaded locally
528 #[arg(long)]
529 hf_cache_path: Option<PathBuf>,
530 },
531
532 /// Select a diffusion plain model, without quantization or adapters
533 DiffusionPlain {
534 /// Model ID to load from. This may be a HF hub repo or a local path.
535 #[arg(short, long)]
536 model_id: String,
537
538 /// The architecture of the model.
539 #[arg(short, long, value_parser = parse_diffusion_arch)]
540 arch: DiffusionLoaderType,
541
542 /// Model data type. Defaults to `auto`.
543 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
544 dtype: ModelDType,
545 },
546
547 Speech {
548 /// Model ID to load from. This may be a HF hub repo or a local path.
549 #[arg(short, long)]
550 model_id: String,
551
552 /// DAC Model ID to load from. If not provided, this is automatically downloaded from the default path for the model.
553 /// This may be a HF hub repo or a local path.
554 #[arg(short, long)]
555 dac_model_id: Option<String>,
556
557 /// The architecture of the model.
558 #[arg(short, long, value_parser = parse_speech_arch)]
559 arch: SpeechLoaderType,
560
561 /// Model data type. Defaults to `auto`.
562 #[arg(long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
563 dtype: ModelDType,
564 },
565}