// mistralrs_core/model_loader.rs

use std::{
    fs::{self, File},
    num::NonZeroUsize,
    path::PathBuf,
    str::FromStr,
};

use mistralrs_quant::MULTI_LORA_DELIMITER;

use crate::{
    get_toml_selected_model_dtype,
    pipeline::{
        AutoLoaderBuilder, DiffusionLoaderBuilder, GGMLLoaderBuilder, GGMLSpecificConfig,
        GGUFLoaderBuilder, GGUFSpecificConfig, NormalLoaderBuilder, NormalSpecificConfig,
        VisionLoaderBuilder, VisionSpecificConfig,
    },
    toml_selector::get_toml_selected_model_device_map_params,
    AutoDeviceMapParams, Loader, ModelDType, ModelSelected, SpeechLoader, TomlLoaderArgs,
    TomlSelector, Topology, GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
};

/// A builder for a loader using the selected model.
pub struct LoaderBuilder {
    model: ModelSelected,
    no_kv_cache: bool,
    chat_template: Option<String>,
    jinja_explicit: Option<String>,
    prompt_chunksize: Option<NonZeroUsize>,
}

impl LoaderBuilder {
    pub fn new(model: ModelSelected) -> Self {
        Self {
            model,
            no_kv_cache: false,
            chat_template: None,
            prompt_chunksize: None,
            jinja_explicit: None,
        }
    }

    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }
    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
        self.chat_template = chat_template;
        self
    }
    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
        self.jinja_explicit = jinja_explicit;
        self
    }
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: Option<NonZeroUsize>) -> Self {
        self.prompt_chunksize = prompt_chunksize;
        self
    }

    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
        loader_from_model_selected(self)
    }
}
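
// A minimal usage sketch of the builder above. `model_selected` stands in for a
// `ModelSelected` value produced elsewhere (e.g. by CLI argument parsing) and is
// not constructed here; the option values are illustrative only:
//
//     let loader = LoaderBuilder::new(model_selected)
//         .with_no_kv_cache(false)
//         .with_chat_template(None)
//         .with_prompt_chunksize(NonZeroUsize::new(512))
//         .build()?;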

/// Returns the target non-granular index for X-LoRA selections; `None` for all
/// other model kinds.
pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
    match model {
        ModelSelected::Plain { .. }
        | ModelSelected::Run { .. }
        | ModelSelected::Lora { .. }
        | ModelSelected::GGUF { .. }
        | ModelSelected::LoraGGUF { .. }
        | ModelSelected::GGML { .. }
        | ModelSelected::LoraGGML { .. }
        | ModelSelected::Toml { .. }
        | ModelSelected::VisionPlain { .. }
        | ModelSelected::DiffusionPlain { .. }
        | ModelSelected::Speech { .. } => None,
        ModelSelected::XLora {
            tgt_non_granular_index,
            ..
        }
        | ModelSelected::XLoraGGUF {
            tgt_non_granular_index,
            ..
        }
        | ModelSelected::XLoraGGML {
            tgt_non_granular_index,
            ..
        } => *tgt_non_granular_index,
        ModelSelected::MultiModel { .. } => {
            panic!("MultiModel variant should not be used in model loading functions")
        }
    }
}

/// Returns the model dtype for the selection, reading the TOML selector file if needed.
pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
    match model {
        ModelSelected::Plain { dtype, .. }
        | ModelSelected::Lora { dtype, .. }
        | ModelSelected::XLora { dtype, .. }
        | ModelSelected::VisionPlain { dtype, .. }
        | ModelSelected::DiffusionPlain { dtype, .. }
        | ModelSelected::GGML { dtype, .. }
        | ModelSelected::GGUF { dtype, .. }
        | ModelSelected::XLoraGGUF { dtype, .. }
        | ModelSelected::XLoraGGML { dtype, .. }
        | ModelSelected::LoraGGUF { dtype, .. }
        | ModelSelected::LoraGGML { dtype, .. }
        | ModelSelected::Run { dtype, .. }
        | ModelSelected::Speech { dtype, .. } => Ok(*dtype),
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            Ok(get_toml_selected_model_dtype(&selector))
        }
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    }
}
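
// For reference, `ModelSelected::Toml` points at a selector file that is parsed
// into `TomlSelector`. A sketch of such a file; the exact schema is defined by
// `TomlSelector`, and the table/field names here are assumptions for illustration:
//
//     [model]
//     model_id = "mistralai/Mistral-7B-Instruct-v0.1"
//     arch = "mistral"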

/// Derives the automatic device-mapping parameters (text or vision) for the
/// selection, reading the TOML selector file if needed.
pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
    match model {
        ModelSelected::Plain {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::Lora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::GGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::GGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::LoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::LoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        } => Ok(AutoDeviceMapParams::Text {
            max_seq_len: *max_seq_len,
            max_batch_size: *max_batch_size,
        }),
        ModelSelected::Run {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => {
            if max_num_images.is_some() || max_image_length.is_some() {
                let max_image_length =
                    max_image_length.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH);
                Ok(AutoDeviceMapParams::Vision {
                    max_seq_len: *max_seq_len,
                    max_batch_size: *max_batch_size,
                    max_image_shape: (max_image_length, max_image_length),
                    max_num_images: max_num_images
                        .unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES),
                })
            } else {
                Ok(AutoDeviceMapParams::Text {
                    max_seq_len: *max_seq_len,
                    max_batch_size: *max_batch_size,
                })
            }
        }
        ModelSelected::VisionPlain {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => Ok(AutoDeviceMapParams::Vision {
            max_seq_len: *max_seq_len,
            max_batch_size: *max_batch_size,
            max_image_shape: (*max_image_length, *max_image_length),
            max_num_images: *max_num_images,
        }),
        ModelSelected::DiffusionPlain { .. } | ModelSelected::Speech { .. } => {
            Ok(AutoDeviceMapParams::default_text())
        }
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            get_toml_selected_model_device_map_params(&selector)
        }
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    }
}
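
// For illustration (values assumed): a `VisionPlain` selection with
// `max_seq_len = 4096`, `max_batch_size = 2`, `max_image_length = 1024`, and
// `max_num_images = 1` yields:
//
//     AutoDeviceMapParams::Vision {
//         max_seq_len: 4096,
//         max_batch_size: 2,
//         max_image_shape: (1024, 1024),
//         max_num_images: 1,
//     }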

/// Constructs the concrete [`Loader`] for the selected model kind.
fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
    let loader: Box<dyn Loader> = match args.model {
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            let args = TomlLoaderArgs {
                chat_template: args.chat_template,
                no_kv_cache: args.no_kv_cache,
                prompt_chunksize: args.prompt_chunksize,
                jinja_explicit: args.jinja_explicit,
            };
            (selector, args).try_into()?
        }
        ModelSelected::Plain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: organization.unwrap_or_default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix,
                calibration_file,
                hf_cache_path,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(arch)?,
        ModelSelected::Run {
            model_id,
            tokenizer_json,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_edge,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
        } => {
            let builder = AutoLoaderBuilder::new(
                NormalSpecificConfig {
                    prompt_chunksize: args.prompt_chunksize,
                    topology: Topology::from_option_path(topology.clone())?,
                    organization: organization.unwrap_or_default(),
                    write_uqff: write_uqff.clone(),
                    from_uqff: from_uqff.clone().map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    imatrix: imatrix.clone(),
                    calibration_file: calibration_file.clone(),
                    hf_cache_path: hf_cache_path.clone(),
                },
                VisionSpecificConfig {
                    prompt_chunksize: args.prompt_chunksize,
                    topology: Topology::from_option_path(topology)?,
                    write_uqff,
                    from_uqff: from_uqff.map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    max_edge,
                    calibration_file,
                    imatrix,
                    hf_cache_path: hf_cache_path.clone(),
                },
                args.chat_template,
                tokenizer_json,
                model_id,
                args.no_kv_cache,
                args.jinja_explicit,
            );
            let builder = if let Some(ref path) = hf_cache_path {
                builder.hf_cache_path(path.clone())
            } else {
                builder
            };
            builder.build()
        }
        ModelSelected::VisionPlain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_edge,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
            imatrix,
        } => VisionLoaderBuilder::new(
            VisionSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                max_edge,
                calibration_file,
                imatrix,
                hf_cache_path,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.jinja_explicit,
        )
        .build(arch),
        ModelSelected::DiffusionPlain {
            model_id,
            arch,
            dtype: _,
        } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
        ModelSelected::Speech {
            model_id,
            dac_model_id,
            arch,
            ..
        } => Box::new(SpeechLoader {
            model_id,
            dac_model_id,
            arch,
            cfg: None,
        }),
        ModelSelected::XLora {
            model_id,
            xlora_model_id,
            order,
            tokenizer_json,
            tgt_non_granular_index,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(arch)?,
        ModelSelected::Lora {
            model_id,
            tokenizer_json,
            adapter_model_id,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapter_model_id
                .split(MULTI_LORA_DELIMITER)
                .map(ToString::to_string)
                .collect(),
        )
        .build(arch)?,
        ModelSelected::GGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        ModelSelected::GGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    };
    Ok(loader)
}