use std::{
    fs::{self, File},
    num::NonZeroUsize,
    path::PathBuf,
    str::FromStr,
};

use mistralrs_quant::MULTI_LORA_DELIMITER;

use crate::{
    get_toml_selected_model_dtype,
    pipeline::{
        AutoLoaderBuilder, DiffusionLoaderBuilder, GGMLLoaderBuilder, GGMLSpecificConfig,
        GGUFLoaderBuilder, GGUFSpecificConfig, NormalLoaderBuilder, NormalSpecificConfig,
        VisionLoaderBuilder, VisionSpecificConfig,
    },
    toml_selector::get_toml_selected_model_device_map_params,
    AutoDeviceMapParams, Loader, ModelDType, ModelSelected, SpeechLoader, TomlLoaderArgs,
    TomlSelector, Topology, GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
};

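/// Builder for a [`Loader`], pairing a [`ModelSelected`] variant with the
/// options shared by every loader kind (chat template, KV-cache policy,
/// prompt chunking, explicit JINJA template).
///
/// A minimal usage sketch (illustrative: the `Toml` variant and its `file`
/// field come from this crate, but the path itself is hypothetical):
///
/// ```ignore
/// let loader = LoaderBuilder::new(ModelSelected::Toml {
///     file: "selectors/model.toml".to_string(),
/// })
/// .with_no_kv_cache(false)
/// .with_chat_template(None)
/// .build()?;
/// ```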
pub struct LoaderBuilder {
    model: ModelSelected,
    no_kv_cache: bool,
    chat_template: Option<String>,
    jinja_explicit: Option<String>,
    prompt_chunksize: Option<NonZeroUsize>,
}

impl LoaderBuilder {
    pub fn new(model: ModelSelected) -> Self {
        Self {
            model,
            no_kv_cache: false,
            chat_template: None,
            prompt_chunksize: None,
            jinja_explicit: None,
        }
    }

    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
        self.chat_template = chat_template;
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
        self.jinja_explicit = jinja_explicit;
        self
    }

    pub fn with_prompt_chunksize(mut self, prompt_chunksize: Option<NonZeroUsize>) -> Self {
        self.prompt_chunksize = prompt_chunksize;
        self
    }

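    /// Consume the builder and construct the loader for the selected model.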
    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
        loader_from_model_selected(self)
    }
}

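/// Get the target non-granular index for the X-LoRA variants; every other
/// variant yields `None`. Panics on `MultiModel`, which must never reach the
/// model loading functions.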
pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
    match model {
        ModelSelected::Plain { .. }
        | ModelSelected::Run { .. }
        | ModelSelected::Lora { .. }
        | ModelSelected::GGUF { .. }
        | ModelSelected::LoraGGUF { .. }
        | ModelSelected::GGML { .. }
        | ModelSelected::LoraGGML { .. }
        | ModelSelected::Toml { .. }
        | ModelSelected::VisionPlain { .. }
        | ModelSelected::DiffusionPlain { .. }
        | ModelSelected::Speech { .. } => None,
        ModelSelected::XLora {
            tgt_non_granular_index,
            ..
        }
        | ModelSelected::XLoraGGUF {
            tgt_non_granular_index,
            ..
        }
        | ModelSelected::XLoraGGML {
            tgt_non_granular_index,
            ..
        } => *tgt_non_granular_index,
        ModelSelected::MultiModel { .. } => {
            panic!("MultiModel variant should not be used in model loading functions")
        }
    }
}

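/// Get the dtype for the selected model, parsing the TOML selector file when
/// the `Toml` variant is used.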
pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
    match model {
        ModelSelected::Plain { dtype, .. }
        | ModelSelected::Lora { dtype, .. }
        | ModelSelected::XLora { dtype, .. }
        | ModelSelected::VisionPlain { dtype, .. }
        | ModelSelected::DiffusionPlain { dtype, .. }
        | ModelSelected::GGML { dtype, .. }
        | ModelSelected::GGUF { dtype, .. }
        | ModelSelected::XLoraGGUF { dtype, .. }
        | ModelSelected::XLoraGGML { dtype, .. }
        | ModelSelected::LoraGGUF { dtype, .. }
        | ModelSelected::LoraGGML { dtype, .. }
        | ModelSelected::Run { dtype, .. }
        | ModelSelected::Speech { dtype, .. } => Ok(*dtype),
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            Ok(get_toml_selected_model_dtype(&selector))
        }
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    }
}

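/// Derive automatic device-mapping parameters for the selected model: text
/// parameters for text-only variants, and vision parameters (maximum image
/// shape and image count) for vision-capable ones.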
pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
    match model {
        ModelSelected::Plain {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::Lora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::GGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::GGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::LoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::LoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        } => Ok(AutoDeviceMapParams::Text {
            max_seq_len: *max_seq_len,
            max_batch_size: *max_batch_size,
        }),
        ModelSelected::Run {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => {
            if max_num_images.is_some() || max_image_length.is_some() {
                let max_image_length =
                    max_image_length.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH);
                Ok(AutoDeviceMapParams::Vision {
                    max_seq_len: *max_seq_len,
                    max_batch_size: *max_batch_size,
                    max_image_shape: (max_image_length, max_image_length),
                    max_num_images: max_num_images
                        .unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES),
                })
            } else {
                Ok(AutoDeviceMapParams::Text {
                    max_seq_len: *max_seq_len,
                    max_batch_size: *max_batch_size,
                })
            }
        }
        ModelSelected::VisionPlain {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => Ok(AutoDeviceMapParams::Vision {
            max_seq_len: *max_seq_len,
            max_batch_size: *max_batch_size,
            max_image_shape: (*max_image_length, *max_image_length),
            max_num_images: *max_num_images,
        }),
        ModelSelected::DiffusionPlain { .. } | ModelSelected::Speech { .. } => {
            Ok(AutoDeviceMapParams::default_text())
        }
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            get_toml_selected_model_device_map_params(&selector)
        }
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    }
}

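/// Dispatch on the selected model variant and build the matching loader:
/// plain/vision/diffusion/speech, GGUF/GGML, and their (X-)LoRA counterparts.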
fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
    let loader: Box<dyn Loader> = match args.model {
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            let args = TomlLoaderArgs {
                chat_template: args.chat_template,
                no_kv_cache: args.no_kv_cache,
                prompt_chunksize: args.prompt_chunksize,
                jinja_explicit: args.jinja_explicit,
            };
            (selector, args).try_into()?
        }
        ModelSelected::Plain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
            matformer_config_path,
            matformer_slice_name,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: organization.unwrap_or_default(),
                write_uqff,
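                // `from_uqff` may name several UQFF files joined by
                // `UQFF_MULTI_FILE_DELIMITER`; split it into one path per
                // file. `PathBuf::from_str` is infallible, so the unwrap
                // cannot fail. The same splitting is repeated in the other
                // UQFF-capable arms below.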
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix,
                calibration_file,
                hf_cache_path,
                matformer_config_path,
                matformer_slice_name,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(arch)?,
        ModelSelected::Run {
            model_id,
            tokenizer_json,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_edge,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
            matformer_config_path,
            matformer_slice_name,
        } => {
            let builder = AutoLoaderBuilder::new(
                NormalSpecificConfig {
                    prompt_chunksize: args.prompt_chunksize,
                    topology: Topology::from_option_path(topology.clone())?,
                    organization: organization.unwrap_or_default(),
                    write_uqff: write_uqff.clone(),
                    from_uqff: from_uqff.clone().map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    imatrix: imatrix.clone(),
                    calibration_file: calibration_file.clone(),
                    hf_cache_path: hf_cache_path.clone(),
                    matformer_config_path: matformer_config_path.clone(),
                    matformer_slice_name: matformer_slice_name.clone(),
                },
                VisionSpecificConfig {
                    prompt_chunksize: args.prompt_chunksize,
                    topology: Topology::from_option_path(topology)?,
                    write_uqff,
                    from_uqff: from_uqff.map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    max_edge,
                    calibration_file,
                    imatrix,
                    hf_cache_path: hf_cache_path.clone(),
                    matformer_config_path,
                    matformer_slice_name,
                },
                args.chat_template,
                tokenizer_json,
                model_id,
                args.no_kv_cache,
                args.jinja_explicit,
            );
            let builder = if let Some(ref path) = hf_cache_path {
                builder.hf_cache_path(path.clone())
            } else {
                builder
            };
            builder.build()
        }
        ModelSelected::VisionPlain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_edge,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
            imatrix,
            matformer_config_path,
            matformer_slice_name,
        } => VisionLoaderBuilder::new(
            VisionSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                max_edge,
                calibration_file,
                imatrix,
                hf_cache_path,
                matformer_config_path,
                matformer_slice_name,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.jinja_explicit,
        )
        .build(arch),
        ModelSelected::DiffusionPlain {
            model_id,
            arch,
            dtype: _,
        } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
        ModelSelected::Speech {
            model_id,
            dac_model_id,
            arch,
            ..
        } => Box::new(SpeechLoader {
            model_id,
            dac_model_id,
            arch,
            cfg: None,
        }),
        ModelSelected::XLora {
            model_id,
            xlora_model_id,
            order,
            tokenizer_json,
            tgt_non_granular_index,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(arch)?,
        ModelSelected::Lora {
            model_id,
            tokenizer_json,
            adapter_model_id,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
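        // A single `adapter_model_id` string may carry several LoRA adapter
        // IDs joined by `MULTI_LORA_DELIMITER`; split them apart here.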
        .with_lora(
            adapter_model_id
                .split(MULTI_LORA_DELIMITER)
                .map(ToString::to_string)
                .collect(),
        )
        .build(arch)?,
        ModelSelected::GGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        ModelSelected::GGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    };
    Ok(loader)
}