mistralrs_core/model_selected.rs
1use std::path::PathBuf;
2
3use clap::Subcommand;
4
5use crate::{
6 pipeline::{AutoDeviceMapParams, IsqOrganization, NormalLoaderType, VisionLoaderType},
7 DiffusionLoaderType, ModelDType,
8};
9
10fn parse_arch(x: &str) -> Result<NormalLoaderType, String> {
11 x.parse()
12}
13
14fn parse_vision_arch(x: &str) -> Result<VisionLoaderType, String> {
15 x.parse()
16}
17
18fn parse_diffusion_arch(x: &str) -> Result<DiffusionLoaderType, String> {
19 x.parse()
20}
21
22fn parse_model_dtype(x: &str) -> Result<ModelDType, String> {
23 x.parse()
24}
25
26#[derive(Debug, Subcommand)]
27pub enum ModelSelected {
28 /// Select the model from a toml file
29 Toml {
30 /// .toml file containing the selector configuration.
31 #[arg(short, long)]
32 file: String,
33 },
34
35 /// Select a plain model, without quantization or adapters
36 Plain {
37 /// Model ID to load from. This may be a HF hub repo or a local path.
38 #[arg(short, long)]
39 model_id: String,
40
41 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
42 #[arg(short, long)]
43 tokenizer_json: Option<String>,
44
45 /// The architecture of the model.
46 #[arg(short, long, value_parser = parse_arch)]
47 arch: Option<NormalLoaderType>,
48
49 /// Model data type. Defaults to `auto`.
50 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
51 dtype: ModelDType,
52
53 /// Path to a topology YAML file.
54 #[arg(long)]
55 topology: Option<String>,
56
57 #[allow(rustdoc::bare_urls)]
58 /// ISQ organization: `default` or `moqe` (Mixture of Quantized Experts: https://arxiv.org/abs/2310.02410).
59 #[arg(short, long)]
60 organization: Option<IsqOrganization>,
61
62 /// UQFF path to write to.
63 #[arg(short, long)]
64 write_uqff: Option<PathBuf>,
65
66 /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
67 #[arg(short, long)]
68 from_uqff: Option<PathBuf>,
69
70 /// .imatrix file to enhance GGUF quantizations with.
71 /// Incompatible with `--calibration-file/-c`
72 #[arg(short, long)]
73 imatrix: Option<PathBuf>,
74
75 /// Generate and utilize an imatrix to enhance GGUF quantizations.
76 /// Incompatible with `--imatrix/-i`
77 #[arg(short, long)]
78 calibration_file: Option<PathBuf>,
79
80 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
81 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
82 max_seq_len: usize,
83
84 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
85 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
86 max_batch_size: usize,
87
88 /// Cache path for Hugging Face models downloaded locally
89 #[arg(short, long)]
90 hf_cache_path: Option<PathBuf>,
91 },
92
93 /// Select an X-LoRA architecture
94 XLora {
95 /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
96 #[arg(short, long)]
97 model_id: Option<String>,
98
99 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
100 #[arg(short, long)]
101 tokenizer_json: Option<String>,
102
103 /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
104 #[arg(short, long)]
105 xlora_model_id: String,
106
107 /// Ordering JSON file
108 #[arg(short, long)]
109 order: String,
110
111 /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
112 /// This makes the maximum running sequences 1.
113 #[arg(long)]
114 tgt_non_granular_index: Option<usize>,
115
116 /// The architecture of the model.
117 #[arg(short, long, value_parser = parse_arch)]
118 arch: Option<NormalLoaderType>,
119
120 /// Model data type. Defaults to `auto`.
121 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
122 dtype: ModelDType,
123
124 /// Path to a topology YAML file.
125 #[arg(long)]
126 topology: Option<String>,
127
128 /// UQFF path to write to.
129 #[arg(short, long)]
130 write_uqff: Option<PathBuf>,
131
132 /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
133 #[arg(short, long)]
134 from_uqff: Option<PathBuf>,
135
136 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
137 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
138 max_seq_len: usize,
139
140 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
141 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
142 max_batch_size: usize,
143
144 /// Cache path for Hugging Face models downloaded locally
145 #[arg(short, long)]
146 hf_cache_path: Option<PathBuf>,
147 },
148
149 /// Select a LoRA architecture
150 Lora {
151 /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
152 #[arg(short, long)]
153 model_id: Option<String>,
154
155 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
156 #[arg(short, long)]
157 tokenizer_json: Option<String>,
158
159 /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
160 #[arg(short, long)]
161 adapter_model_id: String,
162
163 /// The architecture of the model.
164 #[arg(long, value_parser = parse_arch)]
165 arch: Option<NormalLoaderType>,
166
167 /// Model data type. Defaults to `auto`.
168 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
169 dtype: ModelDType,
170
171 /// Path to a topology YAML file.
172 #[arg(long)]
173 topology: Option<String>,
174
175 /// UQFF path to write to.
176 #[arg(short, long)]
177 write_uqff: Option<PathBuf>,
178
179 /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
180 #[arg(short, long)]
181 from_uqff: Option<PathBuf>,
182
183 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
184 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
185 max_seq_len: usize,
186
187 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
188 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
189 max_batch_size: usize,
190
191 /// Cache path for Hugging Face models downloaded locally
192 #[arg(short, long)]
193 hf_cache_path: Option<PathBuf>,
194 },
195
196 /// Select a GGUF model.
197 GGUF {
198 /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
199 /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
200 /// removing all remote accesses.
201 #[arg(short, long)]
202 tok_model_id: Option<String>,
203
204 /// Quantized model ID to find the `quantized_filename`.
205 /// This may be a HF hub repo or a local path.
206 #[arg(short = 'm', long)]
207 quantized_model_id: String,
208
209 /// Quantized filename(s).
210 /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
211 #[arg(short = 'f', long)]
212 quantized_filename: String,
213
214 /// Model data type. Defaults to `auto`.
215 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
216 dtype: ModelDType,
217
218 /// Path to a topology YAML file.
219 #[arg(long)]
220 topology: Option<String>,
221
222 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
223 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
224 max_seq_len: usize,
225
226 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
227 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
228 max_batch_size: usize,
229 },
230
231 /// Select a GGUF model with X-LoRA.
232 XLoraGGUF {
233 /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
234 /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
235 /// removing all remote accesses.
236 #[arg(short, long)]
237 tok_model_id: Option<String>,
238
239 /// Quantized model ID to find the `quantized_filename`.
240 /// This may be a HF hub repo or a local path.
241 #[arg(short = 'm', long)]
242 quantized_model_id: String,
243
244 /// Quantized filename(s).
245 /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
246 #[arg(short = 'f', long)]
247 quantized_filename: String,
248
249 /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
250 #[arg(short, long)]
251 xlora_model_id: String,
252
253 /// Ordering JSON file
254 #[arg(short, long)]
255 order: String,
256
257 /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
258 /// This makes the maximum running sequences 1.
259 #[arg(long)]
260 tgt_non_granular_index: Option<usize>,
261
262 /// Model data type. Defaults to `auto`.
263 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
264 dtype: ModelDType,
265
266 /// Path to a topology YAML file.
267 #[arg(long)]
268 topology: Option<String>,
269
270 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
271 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
272 max_seq_len: usize,
273
274 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
275 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
276 max_batch_size: usize,
277 },
278
279 /// Select a GGUF model with LoRA.
280 LoraGGUF {
281 /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
282 /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
283 /// removing all remote accesses.
284 #[arg(short, long)]
285 tok_model_id: Option<String>,
286
287 /// Quantized model ID to find the `quantized_filename`.
288 /// This may be a HF hub repo or a local path.
289 #[arg(short = 'm', long)]
290 quantized_model_id: String,
291
292 /// Quantized filename(s).
293 /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
294 #[arg(short = 'f', long)]
295 quantized_filename: String,
296
297 /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
298 #[arg(short, long)]
299 adapters_model_id: String,
300
301 /// Ordering JSON file
302 #[arg(short, long)]
303 order: String,
304
305 /// Model data type. Defaults to `auto`.
306 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
307 dtype: ModelDType,
308
309 /// Path to a topology YAML file.
310 #[arg(long)]
311 topology: Option<String>,
312
313 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
314 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
315 max_seq_len: usize,
316
317 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
318 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
319 max_batch_size: usize,
320 },
321
322 /// Select a GGML model.
323 GGML {
324 /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
325 #[arg(short, long)]
326 tok_model_id: String,
327
328 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
329 #[arg(long)]
330 tokenizer_json: Option<String>,
331
332 /// Quantized model ID to find the `quantized_filename`.
333 /// This may be a HF hub repo or a local path.
334 #[arg(short = 'm', long)]
335 quantized_model_id: String,
336
337 /// Quantized filename.
338 #[arg(short = 'f', long)]
339 quantized_filename: String,
340
341 /// GQA value
342 #[arg(short, long, default_value_t = 1)]
343 gqa: usize,
344
345 /// Model data type. Defaults to `auto`.
346 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
347 dtype: ModelDType,
348
349 /// Path to a topology YAML file.
350 #[arg(long)]
351 topology: Option<String>,
352
353 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
354 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
355 max_seq_len: usize,
356
357 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
358 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
359 max_batch_size: usize,
360 },
361
362 /// Select a GGML model with X-LoRA.
363 XLoraGGML {
364 /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
365 #[arg(short, long)]
366 tok_model_id: Option<String>,
367
368 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
369 #[arg(long)]
370 tokenizer_json: Option<String>,
371
372 /// Quantized model ID to find the `quantized_filename`.
373 /// This may be a HF hub repo or a local path.
374 #[arg(short = 'm', long)]
375 quantized_model_id: String,
376
377 /// Quantized filename.
378 #[arg(short = 'f', long)]
379 quantized_filename: String,
380
381 /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
382 #[arg(short, long)]
383 xlora_model_id: String,
384
385 /// Ordering JSON file
386 #[arg(short, long)]
387 order: String,
388
389 /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
390 /// This makes the maximum running sequences 1.
391 #[arg(long)]
392 tgt_non_granular_index: Option<usize>,
393
394 /// GQA value
395 #[arg(short, long, default_value_t = 1)]
396 gqa: usize,
397
398 /// Model data type. Defaults to `auto`.
399 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
400 dtype: ModelDType,
401
402 /// Path to a topology YAML file.
403 #[arg(long)]
404 topology: Option<String>,
405
406 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
407 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
408 max_seq_len: usize,
409
410 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
411 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
412 max_batch_size: usize,
413 },
414
415 /// Select a GGML model with LoRA.
416 LoraGGML {
417 /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
418 #[arg(short, long)]
419 tok_model_id: Option<String>,
420
421 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
422 #[arg(long)]
423 tokenizer_json: Option<String>,
424
425 /// Quantized model ID to find the `quantized_filename`.
426 /// This may be a HF hub repo or a local path.
427 #[arg(short = 'm', long)]
428 quantized_model_id: String,
429
430 /// Quantized filename.
431 #[arg(short = 'f', long)]
432 quantized_filename: String,
433
434 /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
435 #[arg(short, long)]
436 adapters_model_id: String,
437
438 /// Ordering JSON file
439 #[arg(short, long)]
440 order: String,
441
442 /// GQA value
443 #[arg(short, long, default_value_t = 1)]
444 gqa: usize,
445
446 /// Model data type. Defaults to `auto`.
447 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
448 dtype: ModelDType,
449
450 /// Path to a topology YAML file.
451 #[arg(long)]
452 topology: Option<String>,
453
454 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
455 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
456 max_seq_len: usize,
457
458 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
459 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
460 max_batch_size: usize,
461 },
462
463 /// Select a vision plain model, without quantization or adapters
464 VisionPlain {
465 /// Model ID to load from. This may be a HF hub repo or a local path.
466 #[arg(short, long)]
467 model_id: String,
468
469 /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
470 #[arg(short, long)]
471 tokenizer_json: Option<String>,
472
473 /// The architecture of the model.
474 #[arg(short, long, value_parser = parse_vision_arch)]
475 arch: VisionLoaderType,
476
477 /// Model data type. Defaults to `auto`.
478 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
479 dtype: ModelDType,
480
481 /// Path to a topology YAML file.
482 #[arg(long)]
483 topology: Option<String>,
484
485 /// UQFF path to write to.
486 #[arg(short, long)]
487 write_uqff: Option<PathBuf>,
488
489 /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
490 #[arg(short, long)]
491 from_uqff: Option<PathBuf>,
492
493 /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
494 /// This is only supported on the Qwen2-VL and Idefics models. Others handle this internally.
495 #[arg(short = 'e', long)]
496 max_edge: Option<u32>,
497
498 /// Generate and utilize an imatrix to enhance GGUF quantizations.
499 #[arg(short, long)]
500 calibration_file: Option<PathBuf>,
501
502 /// .cimatrix file to enhance GGUF quantizations with. This must be a .cimatrix file.
503 #[arg(short, long)]
504 imatrix: Option<PathBuf>,
505
506 /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
507 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN)]
508 max_seq_len: usize,
509
510 /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
511 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE)]
512 max_batch_size: usize,
513
514 /// Maximum prompt number of images to expect for this model. This affects automatic device mapping but is not a hard limit.
515 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES)]
516 max_num_images: usize,
517
518 /// Maximum expected image size will have this edge length on both edges.
519 /// This affects automatic device mapping but is not a hard limit.
520 #[arg(long, default_value_t = AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH)]
521 max_image_length: usize,
522
523 /// Cache path for Hugging Face models downloaded locally
524 #[arg(short, long)]
525 hf_cache_path: Option<PathBuf>,
526 },
527
528 /// Select a diffusion plain model, without quantization or adapters
529 DiffusionPlain {
530 /// Model ID to load from. This may be a HF hub repo or a local path.
531 #[arg(short, long)]
532 model_id: String,
533
534 /// The architecture of the model.
535 #[arg(short, long, value_parser = parse_diffusion_arch)]
536 arch: DiffusionLoaderType,
537
538 /// Model data type. Defaults to `auto`.
539 #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
540 dtype: ModelDType,
541 },
542}