1use std::{
2 fs::{self, File},
3 num::NonZeroUsize,
4 path::PathBuf,
5 str::FromStr,
6};
7
8use mistralrs_quant::MULTI_LORA_DELIMITER;
9
10use crate::{
11 get_toml_selected_model_dtype,
12 pipeline::{
13 AutoLoaderBuilder, DiffusionLoaderBuilder, GGMLLoaderBuilder, GGMLSpecificConfig,
14 GGUFLoaderBuilder, GGUFSpecificConfig, NormalLoaderBuilder, NormalSpecificConfig,
15 VisionLoaderBuilder, VisionSpecificConfig,
16 },
17 toml_selector::get_toml_selected_model_device_map_params,
18 AutoDeviceMapParams, Loader, ModelDType, ModelSelected, SpeechLoader, TomlLoaderArgs,
19 TomlSelector, Topology, GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
20};
21
/// Builder that collects a model selection plus the loading options shared by
/// all model kinds, and produces a [`Loader`] via [`LoaderBuilder::build`].
pub struct LoaderBuilder {
    // The selected model variant together with its variant-specific options.
    model: ModelSelected,
    // When true, disables the KV cache (forwarded to the pipeline loaders).
    no_kv_cache: bool,
    // Optional chat template override, forwarded to the pipeline loaders.
    chat_template: Option<String>,
    // Optional explicit JINJA chat template, forwarded to the pipeline loaders.
    jinja_explicit: Option<String>,
    // Optional prompt chunk size, forwarded to the loader-specific configs.
    prompt_chunksize: Option<NonZeroUsize>,
}
30
31impl LoaderBuilder {
32 pub fn new(model: ModelSelected) -> Self {
33 Self {
34 model,
35 no_kv_cache: false,
36 chat_template: None,
37 prompt_chunksize: None,
38 jinja_explicit: None,
39 }
40 }
41
42 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
43 self.no_kv_cache = no_kv_cache;
44 self
45 }
46 pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
47 self.chat_template = chat_template;
48 self
49 }
50 pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
51 self.jinja_explicit = jinja_explicit;
52 self
53 }
54 pub fn with_prompt_chunksize(mut self, prompt_chunksize: Option<NonZeroUsize>) -> Self {
55 self.prompt_chunksize = prompt_chunksize;
56 self
57 }
58
59 pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
60 loader_from_model_selected(self)
61 }
62}
63
64pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
65 match model {
66 ModelSelected::Plain { .. }
67 | ModelSelected::Run { .. }
68 | ModelSelected::Lora { .. }
69 | ModelSelected::GGUF { .. }
70 | ModelSelected::LoraGGUF { .. }
71 | ModelSelected::GGML { .. }
72 | ModelSelected::LoraGGML { .. }
73 | ModelSelected::Toml { .. }
74 | ModelSelected::VisionPlain { .. }
75 | ModelSelected::DiffusionPlain { .. }
76 | ModelSelected::Speech { .. } => None,
77 ModelSelected::XLora {
78 tgt_non_granular_index,
79 ..
80 }
81 | ModelSelected::XLoraGGUF {
82 tgt_non_granular_index,
83 ..
84 }
85 | ModelSelected::XLoraGGML {
86 tgt_non_granular_index,
87 ..
88 } => *tgt_non_granular_index,
89 ModelSelected::MultiModel { .. } => {
90 panic!("MultiModel variant should not be used in model loading functions")
91 }
92 }
93}
94
95pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
96 match model {
97 ModelSelected::Plain { dtype, .. }
98 | ModelSelected::Lora { dtype, .. }
99 | ModelSelected::XLora { dtype, .. }
100 | ModelSelected::VisionPlain { dtype, .. }
101 | ModelSelected::DiffusionPlain { dtype, .. }
102 | ModelSelected::GGML { dtype, .. }
103 | ModelSelected::GGUF { dtype, .. }
104 | ModelSelected::XLoraGGUF { dtype, .. }
105 | ModelSelected::XLoraGGML { dtype, .. }
106 | ModelSelected::LoraGGUF { dtype, .. }
107 | ModelSelected::LoraGGML { dtype, .. }
108 | ModelSelected::Run { dtype, .. }
109 | ModelSelected::Speech { dtype, .. } => Ok(*dtype),
110 ModelSelected::Toml { file } => {
111 let selector: TomlSelector = toml::from_str(
112 &fs::read_to_string(file.clone())
113 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
114 )?;
115 Ok(get_toml_selected_model_dtype(&selector))
116 }
117 ModelSelected::MultiModel { .. } => {
118 anyhow::bail!("MultiModel variant should not be used in model loading functions")
119 }
120 }
121}
122
123pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
124 match model {
125 ModelSelected::Plain {
126 max_seq_len,
127 max_batch_size,
128 ..
129 }
130 | ModelSelected::Lora {
131 max_seq_len,
132 max_batch_size,
133 ..
134 }
135 | ModelSelected::XLora {
136 max_seq_len,
137 max_batch_size,
138 ..
139 }
140 | ModelSelected::GGML {
141 max_seq_len,
142 max_batch_size,
143 ..
144 }
145 | ModelSelected::GGUF {
146 max_seq_len,
147 max_batch_size,
148 ..
149 }
150 | ModelSelected::XLoraGGUF {
151 max_seq_len,
152 max_batch_size,
153 ..
154 }
155 | ModelSelected::XLoraGGML {
156 max_seq_len,
157 max_batch_size,
158 ..
159 }
160 | ModelSelected::LoraGGUF {
161 max_seq_len,
162 max_batch_size,
163 ..
164 }
165 | ModelSelected::LoraGGML {
166 max_seq_len,
167 max_batch_size,
168 ..
169 } => Ok(AutoDeviceMapParams::Text {
170 max_seq_len: *max_seq_len,
171 max_batch_size: *max_batch_size,
172 }),
173 ModelSelected::Run {
174 max_seq_len,
175 max_batch_size,
176 max_image_length,
177 max_num_images,
178 ..
179 } => {
180 if max_num_images.is_some() || max_image_length.is_some() {
181 let max_image_length =
182 max_image_length.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH);
183 Ok(AutoDeviceMapParams::Vision {
184 max_seq_len: *max_seq_len,
185 max_batch_size: *max_batch_size,
186 max_image_shape: (max_image_length, max_image_length),
187 max_num_images: max_num_images
188 .unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES),
189 })
190 } else {
191 Ok(AutoDeviceMapParams::Text {
192 max_seq_len: *max_seq_len,
193 max_batch_size: *max_batch_size,
194 })
195 }
196 }
197 ModelSelected::VisionPlain {
198 max_seq_len,
199 max_batch_size,
200 max_image_length,
201 max_num_images,
202 ..
203 } => Ok(AutoDeviceMapParams::Vision {
204 max_seq_len: *max_seq_len,
205 max_batch_size: *max_batch_size,
206 max_image_shape: (*max_image_length, *max_image_length),
207 max_num_images: *max_num_images,
208 }),
209 ModelSelected::DiffusionPlain { .. } | ModelSelected::Speech { .. } => {
210 Ok(AutoDeviceMapParams::default_text())
211 }
212 ModelSelected::Toml { file } => {
213 let selector: TomlSelector = toml::from_str(
214 &fs::read_to_string(file.clone())
215 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
216 )?;
217 get_toml_selected_model_device_map_params(&selector)
218 }
219 ModelSelected::MultiModel { .. } => {
220 anyhow::bail!("MultiModel variant should not be used in model loading functions")
221 }
222 }
223}
224
225fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
226 let loader: Box<dyn Loader> = match args.model {
227 ModelSelected::Toml { file } => {
228 let selector: TomlSelector = toml::from_str(
229 &fs::read_to_string(file.clone())
230 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
231 )?;
232 let args = TomlLoaderArgs {
233 chat_template: args.chat_template,
234 no_kv_cache: args.no_kv_cache,
235 prompt_chunksize: args.prompt_chunksize,
236 jinja_explicit: args.jinja_explicit,
237 };
238 (selector, args).try_into()?
239 }
240 ModelSelected::Plain {
241 model_id,
242 tokenizer_json,
243 arch,
244 dtype: _,
245 topology,
246 organization,
247 write_uqff,
248 from_uqff,
249 imatrix,
250 calibration_file,
251 max_seq_len: _,
252 max_batch_size: _,
253 hf_cache_path,
254 } => NormalLoaderBuilder::new(
255 NormalSpecificConfig {
256 prompt_chunksize: args.prompt_chunksize,
257 topology: Topology::from_option_path(topology)?,
258 organization: organization.unwrap_or_default(),
259 write_uqff,
260 from_uqff: from_uqff.map(|x| {
261 x.split(UQFF_MULTI_FILE_DELIMITER)
262 .map(PathBuf::from_str)
263 .map(|x| x.unwrap())
264 .collect::<Vec<_>>()
265 }),
266 imatrix,
267 calibration_file,
268 hf_cache_path,
269 },
270 args.chat_template,
271 tokenizer_json,
272 Some(model_id),
273 args.no_kv_cache,
274 args.jinja_explicit,
275 )
276 .build(arch)?,
277 ModelSelected::Run {
278 model_id,
279 tokenizer_json,
280 dtype: _,
281 topology,
282 organization,
283 write_uqff,
284 from_uqff,
285 imatrix,
286 calibration_file,
287 max_edge,
288 max_seq_len: _,
289 max_batch_size: _,
290 max_num_images: _,
291 max_image_length: _,
292 hf_cache_path,
293 } => {
294 let builder = AutoLoaderBuilder::new(
295 NormalSpecificConfig {
296 prompt_chunksize: args.prompt_chunksize,
297 topology: Topology::from_option_path(topology.clone())?,
298 organization: organization.unwrap_or_default(),
299 write_uqff: write_uqff.clone(),
300 from_uqff: from_uqff.clone().map(|x| {
301 x.split(UQFF_MULTI_FILE_DELIMITER)
302 .map(PathBuf::from_str)
303 .map(|x| x.unwrap())
304 .collect::<Vec<_>>()
305 }),
306 imatrix: imatrix.clone(),
307 calibration_file: calibration_file.clone(),
308 hf_cache_path: hf_cache_path.clone(),
309 },
310 VisionSpecificConfig {
311 prompt_chunksize: args.prompt_chunksize,
312 topology: Topology::from_option_path(topology)?,
313 write_uqff,
314 from_uqff: from_uqff.map(|x| {
315 x.split(UQFF_MULTI_FILE_DELIMITER)
316 .map(PathBuf::from_str)
317 .map(|x| x.unwrap())
318 .collect::<Vec<_>>()
319 }),
320 max_edge,
321 calibration_file,
322 imatrix,
323 hf_cache_path: hf_cache_path.clone(),
324 },
325 args.chat_template,
326 tokenizer_json,
327 model_id,
328 args.no_kv_cache,
329 args.jinja_explicit,
330 );
331 let builder = if let Some(ref path) = hf_cache_path {
332 builder.hf_cache_path(path.clone())
333 } else {
334 builder
335 };
336 builder.build()
337 }
338 ModelSelected::VisionPlain {
339 model_id,
340 tokenizer_json,
341 arch,
342 dtype: _,
343 topology,
344 write_uqff,
345 from_uqff,
346 max_edge,
347 calibration_file,
348 max_seq_len: _,
349 max_batch_size: _,
350 max_num_images: _,
351 max_image_length: _,
352 hf_cache_path,
353 imatrix,
354 } => VisionLoaderBuilder::new(
355 VisionSpecificConfig {
356 prompt_chunksize: args.prompt_chunksize,
357 topology: Topology::from_option_path(topology)?,
358 write_uqff,
359 from_uqff: from_uqff.map(|x| {
360 x.split(UQFF_MULTI_FILE_DELIMITER)
361 .map(PathBuf::from_str)
362 .map(|x| x.unwrap())
363 .collect::<Vec<_>>()
364 }),
365 max_edge,
366 calibration_file,
367 imatrix,
368 hf_cache_path,
369 },
370 args.chat_template,
371 tokenizer_json,
372 Some(model_id),
373 args.jinja_explicit,
374 )
375 .build(arch),
376 ModelSelected::DiffusionPlain {
377 model_id,
378 arch,
379 dtype: _,
380 } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
381 ModelSelected::Speech {
382 model_id,
383 dac_model_id,
384 arch,
385 ..
386 } => Box::new(SpeechLoader {
387 model_id,
388 dac_model_id,
389 arch,
390 cfg: None,
391 }),
392 ModelSelected::XLora {
393 model_id,
394 xlora_model_id,
395 order,
396 tokenizer_json,
397 tgt_non_granular_index,
398 arch,
399 dtype: _,
400 topology,
401 write_uqff,
402 from_uqff,
403 max_seq_len: _,
404 max_batch_size: _,
405 hf_cache_path,
406 } => NormalLoaderBuilder::new(
407 NormalSpecificConfig {
408 prompt_chunksize: args.prompt_chunksize,
409 topology: Topology::from_option_path(topology)?,
410 organization: Default::default(),
411 write_uqff,
412 from_uqff: from_uqff.map(|x| {
413 x.split(UQFF_MULTI_FILE_DELIMITER)
414 .map(PathBuf::from_str)
415 .map(|x| x.unwrap())
416 .collect::<Vec<_>>()
417 }),
418 imatrix: None,
419 calibration_file: None,
420 hf_cache_path,
421 },
422 args.chat_template,
423 tokenizer_json,
424 model_id,
425 args.no_kv_cache,
426 args.jinja_explicit,
427 )
428 .with_xlora(
429 xlora_model_id,
430 serde_json::from_reader(
431 File::open(order.clone())
432 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
433 )?,
434 args.no_kv_cache,
435 tgt_non_granular_index,
436 )
437 .build(arch)?,
438 ModelSelected::Lora {
439 model_id,
440 tokenizer_json,
441 adapter_model_id,
442 arch,
443 dtype: _,
444 topology,
445 write_uqff,
446 from_uqff,
447 max_seq_len: _,
448 max_batch_size: _,
449 hf_cache_path,
450 } => NormalLoaderBuilder::new(
451 NormalSpecificConfig {
452 prompt_chunksize: args.prompt_chunksize,
453 topology: Topology::from_option_path(topology)?,
454 organization: Default::default(),
455 write_uqff,
456 from_uqff: from_uqff.map(|x| {
457 x.split(UQFF_MULTI_FILE_DELIMITER)
458 .map(PathBuf::from_str)
459 .map(|x| x.unwrap())
460 .collect::<Vec<_>>()
461 }),
462 imatrix: None,
463 calibration_file: None,
464 hf_cache_path,
465 },
466 args.chat_template,
467 tokenizer_json,
468 model_id,
469 args.no_kv_cache,
470 args.jinja_explicit,
471 )
472 .with_lora(
473 adapter_model_id
474 .split(MULTI_LORA_DELIMITER)
475 .map(ToString::to_string)
476 .collect(),
477 )
478 .build(arch)?,
479 ModelSelected::GGUF {
480 tok_model_id,
481 quantized_model_id,
482 quantized_filename,
483 topology,
484 ..
485 } => GGUFLoaderBuilder::new(
486 args.chat_template,
487 tok_model_id,
488 quantized_model_id,
489 quantized_filename
490 .split(GGUF_MULTI_FILE_DELIMITER)
491 .map(ToOwned::to_owned)
492 .collect::<Vec<_>>(),
493 GGUFSpecificConfig {
494 prompt_chunksize: args.prompt_chunksize,
495 topology: Topology::from_option_path(topology)?,
496 },
497 args.no_kv_cache,
498 args.jinja_explicit,
499 )
500 .build(),
501 ModelSelected::XLoraGGUF {
502 tok_model_id,
503 quantized_model_id,
504 quantized_filename,
505 xlora_model_id,
506 order,
507 tgt_non_granular_index,
508 topology,
509 ..
510 } => GGUFLoaderBuilder::new(
511 args.chat_template,
512 tok_model_id,
513 quantized_model_id,
514 quantized_filename
515 .split(GGUF_MULTI_FILE_DELIMITER)
516 .map(ToOwned::to_owned)
517 .collect::<Vec<_>>(),
518 GGUFSpecificConfig {
519 prompt_chunksize: args.prompt_chunksize,
520 topology: Topology::from_option_path(topology)?,
521 },
522 args.no_kv_cache,
523 args.jinja_explicit,
524 )
525 .with_xlora(
526 xlora_model_id,
527 serde_json::from_reader(
528 File::open(order.clone())
529 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
530 )?,
531 args.no_kv_cache,
532 tgt_non_granular_index,
533 )
534 .build(),
535 ModelSelected::LoraGGUF {
536 tok_model_id,
537 quantized_model_id,
538 quantized_filename,
539 adapters_model_id,
540 order,
541 topology,
542 ..
543 } => GGUFLoaderBuilder::new(
544 args.chat_template,
545 tok_model_id,
546 quantized_model_id,
547 quantized_filename
548 .split(GGUF_MULTI_FILE_DELIMITER)
549 .map(ToOwned::to_owned)
550 .collect::<Vec<_>>(),
551 GGUFSpecificConfig {
552 prompt_chunksize: args.prompt_chunksize,
553 topology: Topology::from_option_path(topology)?,
554 },
555 args.no_kv_cache,
556 args.jinja_explicit,
557 )
558 .with_lora(
559 adapters_model_id,
560 serde_json::from_reader(
561 File::open(order.clone())
562 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
563 )?,
564 )
565 .build(),
566 ModelSelected::GGML {
567 tok_model_id,
568 tokenizer_json,
569 quantized_model_id,
570 quantized_filename,
571 gqa,
572 topology,
573 ..
574 } => GGMLLoaderBuilder::new(
575 GGMLSpecificConfig {
576 gqa,
577 prompt_chunksize: args.prompt_chunksize,
578 topology: Topology::from_option_path(topology)?,
579 },
580 args.chat_template,
581 tokenizer_json,
582 Some(tok_model_id),
583 quantized_model_id,
584 quantized_filename,
585 args.no_kv_cache,
586 args.jinja_explicit,
587 )
588 .build(),
589 ModelSelected::XLoraGGML {
590 tok_model_id,
591 tokenizer_json,
592 quantized_model_id,
593 quantized_filename,
594 xlora_model_id,
595 order,
596 tgt_non_granular_index,
597 gqa,
598 topology,
599 ..
600 } => GGMLLoaderBuilder::new(
601 GGMLSpecificConfig {
602 gqa,
603 prompt_chunksize: args.prompt_chunksize,
604 topology: Topology::from_option_path(topology)?,
605 },
606 args.chat_template,
607 tokenizer_json,
608 tok_model_id,
609 quantized_model_id,
610 quantized_filename,
611 args.no_kv_cache,
612 args.jinja_explicit,
613 )
614 .with_xlora(
615 xlora_model_id,
616 serde_json::from_reader(
617 File::open(order.clone())
618 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
619 )?,
620 args.no_kv_cache,
621 tgt_non_granular_index,
622 )
623 .build(),
624 ModelSelected::LoraGGML {
625 tok_model_id,
626 tokenizer_json,
627 quantized_model_id,
628 quantized_filename,
629 adapters_model_id,
630 order,
631 gqa,
632 topology,
633 ..
634 } => GGMLLoaderBuilder::new(
635 GGMLSpecificConfig {
636 gqa,
637 prompt_chunksize: args.prompt_chunksize,
638 topology: Topology::from_option_path(topology)?,
639 },
640 args.chat_template,
641 tokenizer_json,
642 tok_model_id,
643 quantized_model_id,
644 quantized_filename,
645 args.no_kv_cache,
646 args.jinja_explicit,
647 )
648 .with_lora(
649 adapters_model_id,
650 serde_json::from_reader(
651 File::open(order.clone())
652 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
653 )?,
654 )
655 .build(),
656 ModelSelected::MultiModel { .. } => {
657 anyhow::bail!("MultiModel variant should not be used in model loading functions")
658 }
659 };
660 Ok(loader)
661}