use std::{
    fs::{self, File},
    path::PathBuf,
    str::FromStr,
};

use mistralrs_quant::MULTI_LORA_DELIMITER;

use crate::{
    get_toml_selected_model_dtype,
    pipeline::{
        AutoLoaderBuilder, DiffusionLoaderBuilder, GGMLLoaderBuilder, GGMLSpecificConfig,
        GGUFLoaderBuilder, GGUFSpecificConfig, NormalLoaderBuilder, NormalSpecificConfig,
        VisionLoaderBuilder, VisionSpecificConfig,
    },
    toml_selector::get_toml_selected_model_device_map_params,
    AutoDeviceMapParams, Loader, ModelDType, ModelSelected, SpeechLoader, TomlLoaderArgs,
    TomlSelector, Topology, GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
};

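/// Builder that turns a [`ModelSelected`] specification into a boxed [`Loader`].
///
/// A minimal usage sketch (assumes `model` is an already-constructed
/// [`ModelSelected`] value; the setters shown are optional and default to
/// `false`/`None`):
///
/// ```ignore
/// let loader = LoaderBuilder::new(model)
///     .with_no_kv_cache(false)
///     .with_chat_template(None)
///     .with_jinja_explicit(None)
///     .build()?;
/// ```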
pub struct LoaderBuilder {
    model: ModelSelected,
    no_kv_cache: bool,
    chat_template: Option<String>,
    jinja_explicit: Option<String>,
}

impl LoaderBuilder {
    pub fn new(model: ModelSelected) -> Self {
        Self {
            model,
            no_kv_cache: false,
            chat_template: None,
            jinja_explicit: None,
        }
    }

    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
        self.chat_template = chat_template;
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
        self.jinja_explicit = jinja_explicit;
        self
    }

    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
        loader_from_model_selected(self)
    }
}

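/// Returns the target non-granular index for the X-LoRA selections (`XLora`,
/// `XLoraGGUF`, `XLoraGGML`); every other variant yields `None`, and
/// `MultiModel` panics because it is not valid in model loading functions.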
pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
    match model {
        ModelSelected::Plain { .. }
        | ModelSelected::Run { .. }
        | ModelSelected::Lora { .. }
        | ModelSelected::GGUF { .. }
        | ModelSelected::LoraGGUF { .. }
        | ModelSelected::GGML { .. }
        | ModelSelected::LoraGGML { .. }
        | ModelSelected::Toml { .. }
        | ModelSelected::VisionPlain { .. }
        | ModelSelected::DiffusionPlain { .. }
        | ModelSelected::Speech { .. } => None,
        ModelSelected::XLora {
            tgt_non_granular_index,
            ..
        }
        | ModelSelected::XLoraGGUF {
            tgt_non_granular_index,
            ..
        }
        | ModelSelected::XLoraGGML {
            tgt_non_granular_index,
            ..
        } => *tgt_non_granular_index,
        ModelSelected::MultiModel { .. } => {
            panic!("MultiModel variant should not be used in model loading functions")
        }
    }
}

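/// Resolves the [`ModelDType`] for the given selection. For `Toml` selections the
/// dtype is read from the selector file; `MultiModel` is rejected with an error.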
pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
    match model {
        ModelSelected::Plain { dtype, .. }
        | ModelSelected::Lora { dtype, .. }
        | ModelSelected::XLora { dtype, .. }
        | ModelSelected::VisionPlain { dtype, .. }
        | ModelSelected::DiffusionPlain { dtype, .. }
        | ModelSelected::GGML { dtype, .. }
        | ModelSelected::GGUF { dtype, .. }
        | ModelSelected::XLoraGGUF { dtype, .. }
        | ModelSelected::XLoraGGML { dtype, .. }
        | ModelSelected::LoraGGUF { dtype, .. }
        | ModelSelected::LoraGGML { dtype, .. }
        | ModelSelected::Run { dtype, .. }
        | ModelSelected::Speech { dtype, .. } => Ok(*dtype),
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            Ok(get_toml_selected_model_dtype(&selector))
        }
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    }
}

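/// Builds the [`AutoDeviceMapParams`] for the given selection: vision parameters for
/// `VisionPlain` (and for `Run` when image limits are supplied), text parameters for
/// the text-only variants, defaults for diffusion/speech, and the selector file's
/// values for `Toml`.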
pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
    match model {
        ModelSelected::Plain {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::Lora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::GGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::GGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::LoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::LoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        } => Ok(AutoDeviceMapParams::Text {
            max_seq_len: *max_seq_len,
            max_batch_size: *max_batch_size,
        }),
        ModelSelected::Run {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => {
            if max_num_images.is_some() || max_image_length.is_some() {
                let max_image_length =
                    max_image_length.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH);
                Ok(AutoDeviceMapParams::Vision {
                    max_seq_len: *max_seq_len,
                    max_batch_size: *max_batch_size,
                    max_image_shape: (max_image_length, max_image_length),
                    max_num_images: max_num_images
                        .unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES),
                })
            } else {
                Ok(AutoDeviceMapParams::Text {
                    max_seq_len: *max_seq_len,
                    max_batch_size: *max_batch_size,
                })
            }
        }
        ModelSelected::VisionPlain {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => Ok(AutoDeviceMapParams::Vision {
            max_seq_len: *max_seq_len,
            max_batch_size: *max_batch_size,
            max_image_shape: (*max_image_length, *max_image_length),
            max_num_images: *max_num_images,
        }),
        ModelSelected::DiffusionPlain { .. } | ModelSelected::Speech { .. } => {
            Ok(AutoDeviceMapParams::default_text())
        }
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            get_toml_selected_model_device_map_params(&selector)
        }
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    }
}

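/// Dispatches on the [`ModelSelected`] variant to construct the concrete [`Loader`],
/// threading through the chat template, KV-cache setting, and explicit Jinja template
/// from the [`LoaderBuilder`].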
fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
    let loader: Box<dyn Loader> = match args.model {
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            let args = TomlLoaderArgs {
                chat_template: args.chat_template,
                no_kv_cache: args.no_kv_cache,
                jinja_explicit: args.jinja_explicit,
            };
            (selector, args).try_into()?
        }
        ModelSelected::Plain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
            matformer_config_path,
            matformer_slice_name,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                organization: organization.unwrap_or_default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix,
                calibration_file,
                hf_cache_path,
                matformer_config_path,
                matformer_slice_name,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(arch)?,
        ModelSelected::Run {
            model_id,
            tokenizer_json,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_edge,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
            matformer_config_path,
            matformer_slice_name,
        } => {
            let builder = AutoLoaderBuilder::new(
                NormalSpecificConfig {
                    topology: Topology::from_option_path(topology.clone())?,
                    organization: organization.unwrap_or_default(),
                    write_uqff: write_uqff.clone(),
                    from_uqff: from_uqff.clone().map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    imatrix: imatrix.clone(),
                    calibration_file: calibration_file.clone(),
                    hf_cache_path: hf_cache_path.clone(),
                    matformer_config_path: matformer_config_path.clone(),
                    matformer_slice_name: matformer_slice_name.clone(),
                },
                VisionSpecificConfig {
                    topology: Topology::from_option_path(topology)?,
                    write_uqff,
                    from_uqff: from_uqff.map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    max_edge,
                    calibration_file,
                    imatrix,
                    hf_cache_path: hf_cache_path.clone(),
                    matformer_config_path,
                    matformer_slice_name,
                },
                args.chat_template,
                tokenizer_json,
                model_id,
                args.no_kv_cache,
                args.jinja_explicit,
            );
            let builder = if let Some(ref path) = hf_cache_path {
                builder.hf_cache_path(path.clone())
            } else {
                builder
            };
            builder.build()
        }
        ModelSelected::VisionPlain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_edge,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
            imatrix,
            matformer_config_path,
            matformer_slice_name,
        } => VisionLoaderBuilder::new(
            VisionSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                max_edge,
                calibration_file,
                imatrix,
                hf_cache_path,
                matformer_config_path,
                matformer_slice_name,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.jinja_explicit,
        )
        .build(arch),
        ModelSelected::DiffusionPlain {
            model_id,
            arch,
            dtype: _,
        } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
        ModelSelected::Speech {
            model_id,
            dac_model_id,
            arch,
            ..
        } => Box::new(SpeechLoader {
            model_id,
            dac_model_id,
            arch,
            cfg: None,
        }),
        ModelSelected::XLora {
            model_id,
            xlora_model_id,
            order,
            tokenizer_json,
            tgt_non_granular_index,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(arch)?,
        ModelSelected::Lora {
            model_id,
            tokenizer_json,
            adapter_model_id,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapter_model_id
                .split(MULTI_LORA_DELIMITER)
                .map(ToString::to_string)
                .collect(),
        )
        .build(arch)?,
        ModelSelected::GGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        ModelSelected::GGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    };
    Ok(loader)
}