// mistralrs_core/model_loader.rs

1use std::{
2    fs::{self, File},
3    num::NonZeroUsize,
4    path::PathBuf,
5    str::FromStr,
6};
7
8use mistralrs_quant::MULTI_LORA_DELIMITER;
9
10use crate::{
11    get_toml_selected_model_dtype,
12    pipeline::{
13        AutoLoaderBuilder, DiffusionLoaderBuilder, GGMLLoaderBuilder, GGMLSpecificConfig,
14        GGUFLoaderBuilder, GGUFSpecificConfig, NormalLoaderBuilder, NormalSpecificConfig,
15        VisionLoaderBuilder, VisionSpecificConfig,
16    },
17    toml_selector::get_toml_selected_model_device_map_params,
18    AutoDeviceMapParams, Loader, ModelDType, ModelSelected, SpeechLoader, TomlLoaderArgs,
19    TomlSelector, Topology, GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
20};
21
/// A builder for a loader using the selected model.
pub struct LoaderBuilder {
    // The model selection this builder turns into a `Loader` (see `build`).
    model: ModelSelected,
    // Disable the KV cache; forwarded to the underlying pipeline builders.
    no_kv_cache: bool,
    // Optional chat template override, forwarded to the pipeline builders.
    chat_template: Option<String>,
    // Optional explicit JINJA template, forwarded to the pipeline builders.
    jinja_explicit: Option<String>,
    // Optional prompt chunk size, forwarded as `prompt_chunksize` in the
    // per-pipeline specific configs.
    prompt_chunksize: Option<NonZeroUsize>,
}
30
31impl LoaderBuilder {
32    pub fn new(model: ModelSelected) -> Self {
33        Self {
34            model,
35            no_kv_cache: false,
36            chat_template: None,
37            prompt_chunksize: None,
38            jinja_explicit: None,
39        }
40    }
41
42    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
43        self.no_kv_cache = no_kv_cache;
44        self
45    }
46    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
47        self.chat_template = chat_template;
48        self
49    }
50    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
51        self.jinja_explicit = jinja_explicit;
52        self
53    }
54    pub fn with_prompt_chunksize(mut self, prompt_chunksize: Option<NonZeroUsize>) -> Self {
55        self.prompt_chunksize = prompt_chunksize;
56        self
57    }
58
59    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
60        loader_from_model_selected(self)
61    }
62}
63
64pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
65    match model {
66        ModelSelected::Plain { .. }
67        | ModelSelected::Run { .. }
68        | ModelSelected::Lora { .. }
69        | ModelSelected::GGUF { .. }
70        | ModelSelected::LoraGGUF { .. }
71        | ModelSelected::GGML { .. }
72        | ModelSelected::LoraGGML { .. }
73        | ModelSelected::Toml { .. }
74        | ModelSelected::VisionPlain { .. }
75        | ModelSelected::DiffusionPlain { .. }
76        | ModelSelected::Speech { .. } => None,
77        ModelSelected::XLora {
78            tgt_non_granular_index,
79            ..
80        }
81        | ModelSelected::XLoraGGUF {
82            tgt_non_granular_index,
83            ..
84        }
85        | ModelSelected::XLoraGGML {
86            tgt_non_granular_index,
87            ..
88        } => *tgt_non_granular_index,
89        ModelSelected::MultiModel { .. } => {
90            panic!("MultiModel variant should not be used in model loading functions")
91        }
92    }
93}
94
95pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
96    match model {
97        ModelSelected::Plain { dtype, .. }
98        | ModelSelected::Lora { dtype, .. }
99        | ModelSelected::XLora { dtype, .. }
100        | ModelSelected::VisionPlain { dtype, .. }
101        | ModelSelected::DiffusionPlain { dtype, .. }
102        | ModelSelected::GGML { dtype, .. }
103        | ModelSelected::GGUF { dtype, .. }
104        | ModelSelected::XLoraGGUF { dtype, .. }
105        | ModelSelected::XLoraGGML { dtype, .. }
106        | ModelSelected::LoraGGUF { dtype, .. }
107        | ModelSelected::LoraGGML { dtype, .. }
108        | ModelSelected::Run { dtype, .. }
109        | ModelSelected::Speech { dtype, .. } => Ok(*dtype),
110        ModelSelected::Toml { file } => {
111            let selector: TomlSelector = toml::from_str(
112                &fs::read_to_string(file.clone())
113                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
114            )?;
115            Ok(get_toml_selected_model_dtype(&selector))
116        }
117        ModelSelected::MultiModel { .. } => {
118            anyhow::bail!("MultiModel variant should not be used in model loading functions")
119        }
120    }
121}
122
123pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
124    match model {
125        ModelSelected::Plain {
126            max_seq_len,
127            max_batch_size,
128            ..
129        }
130        | ModelSelected::Lora {
131            max_seq_len,
132            max_batch_size,
133            ..
134        }
135        | ModelSelected::XLora {
136            max_seq_len,
137            max_batch_size,
138            ..
139        }
140        | ModelSelected::GGML {
141            max_seq_len,
142            max_batch_size,
143            ..
144        }
145        | ModelSelected::GGUF {
146            max_seq_len,
147            max_batch_size,
148            ..
149        }
150        | ModelSelected::XLoraGGUF {
151            max_seq_len,
152            max_batch_size,
153            ..
154        }
155        | ModelSelected::XLoraGGML {
156            max_seq_len,
157            max_batch_size,
158            ..
159        }
160        | ModelSelected::LoraGGUF {
161            max_seq_len,
162            max_batch_size,
163            ..
164        }
165        | ModelSelected::LoraGGML {
166            max_seq_len,
167            max_batch_size,
168            ..
169        } => Ok(AutoDeviceMapParams::Text {
170            max_seq_len: *max_seq_len,
171            max_batch_size: *max_batch_size,
172        }),
173        ModelSelected::Run {
174            max_seq_len,
175            max_batch_size,
176            max_image_length,
177            max_num_images,
178            ..
179        } => {
180            if max_num_images.is_some() || max_image_length.is_some() {
181                let max_image_length =
182                    max_image_length.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH);
183                Ok(AutoDeviceMapParams::Vision {
184                    max_seq_len: *max_seq_len,
185                    max_batch_size: *max_batch_size,
186                    max_image_shape: (max_image_length, max_image_length),
187                    max_num_images: max_num_images
188                        .unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES),
189                })
190            } else {
191                Ok(AutoDeviceMapParams::Text {
192                    max_seq_len: *max_seq_len,
193                    max_batch_size: *max_batch_size,
194                })
195            }
196        }
197        ModelSelected::VisionPlain {
198            max_seq_len,
199            max_batch_size,
200            max_image_length,
201            max_num_images,
202            ..
203        } => Ok(AutoDeviceMapParams::Vision {
204            max_seq_len: *max_seq_len,
205            max_batch_size: *max_batch_size,
206            max_image_shape: (*max_image_length, *max_image_length),
207            max_num_images: *max_num_images,
208        }),
209        ModelSelected::DiffusionPlain { .. } | ModelSelected::Speech { .. } => {
210            Ok(AutoDeviceMapParams::default_text())
211        }
212        ModelSelected::Toml { file } => {
213            let selector: TomlSelector = toml::from_str(
214                &fs::read_to_string(file.clone())
215                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
216            )?;
217            get_toml_selected_model_device_map_params(&selector)
218        }
219        ModelSelected::MultiModel { .. } => {
220            anyhow::bail!("MultiModel variant should not be used in model loading functions")
221        }
222    }
223}
224
/// Construct the concrete [`Loader`] for the model selected in `args`,
/// dispatching on the `ModelSelected` variant and threading the builder's
/// common options (chat template, KV-cache flag, prompt chunk size, explicit
/// JINJA template) into the matching pipeline builder.
///
/// # Panics
/// Panics if a referenced TOML selector file or adapter ordering file cannot
/// be opened.
fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
    let loader: Box<dyn Loader> = match args.model {
        // TOML selection: all loader configuration comes from the selector
        // file; the builder's options are carried over via `TomlLoaderArgs`.
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            let args = TomlLoaderArgs {
                chat_template: args.chat_template,
                no_kv_cache: args.no_kv_cache,
                prompt_chunksize: args.prompt_chunksize,
                jinja_explicit: args.jinja_explicit,
            };
            (selector, args).try_into()?
        }
        ModelSelected::Plain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
            matformer_config_path,
            matformer_slice_name,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: organization.unwrap_or_default(),
                write_uqff,
                // `from_uqff` may name several files separated by the UQFF
                // delimiter; `PathBuf::from_str` is infallible, so the
                // `unwrap` cannot panic.
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix,
                calibration_file,
                hf_cache_path,
                matformer_config_path,
                matformer_slice_name,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(arch)?,
        // `Run` defers the text-vs-vision decision to `AutoLoaderBuilder`,
        // so both specific configs are constructed (hence the clones for the
        // fields used by each).
        ModelSelected::Run {
            model_id,
            tokenizer_json,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_edge,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
            matformer_config_path,
            matformer_slice_name,
        } => {
            let builder = AutoLoaderBuilder::new(
                NormalSpecificConfig {
                    prompt_chunksize: args.prompt_chunksize,
                    topology: Topology::from_option_path(topology.clone())?,
                    organization: organization.unwrap_or_default(),
                    write_uqff: write_uqff.clone(),
                    from_uqff: from_uqff.clone().map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    imatrix: imatrix.clone(),
                    calibration_file: calibration_file.clone(),
                    hf_cache_path: hf_cache_path.clone(),
                    matformer_config_path: matformer_config_path.clone(),
                    matformer_slice_name: matformer_slice_name.clone(),
                },
                VisionSpecificConfig {
                    prompt_chunksize: args.prompt_chunksize,
                    topology: Topology::from_option_path(topology)?,
                    write_uqff,
                    from_uqff: from_uqff.map(|x| {
                        x.split(UQFF_MULTI_FILE_DELIMITER)
                            .map(PathBuf::from_str)
                            .map(|x| x.unwrap())
                            .collect::<Vec<_>>()
                    }),
                    max_edge,
                    calibration_file,
                    imatrix,
                    hf_cache_path: hf_cache_path.clone(),
                    matformer_config_path,
                    matformer_slice_name,
                },
                args.chat_template,
                tokenizer_json,
                model_id,
                args.no_kv_cache,
                args.jinja_explicit,
            );
            // Only set the HF cache path on the builder when one was given.
            let builder = if let Some(ref path) = hf_cache_path {
                builder.hf_cache_path(path.clone())
            } else {
                builder
            };
            builder.build()
        }
        ModelSelected::VisionPlain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_edge,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
            imatrix,
            matformer_config_path,
            matformer_slice_name,
        } => VisionLoaderBuilder::new(
            VisionSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                max_edge,
                calibration_file,
                imatrix,
                hf_cache_path,
                matformer_config_path,
                matformer_slice_name,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.jinja_explicit,
        )
        .build(arch),
        ModelSelected::DiffusionPlain {
            model_id,
            arch,
            dtype: _,
        } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
        ModelSelected::Speech {
            model_id,
            dac_model_id,
            arch,
            ..
        } => Box::new(SpeechLoader {
            model_id,
            dac_model_id,
            arch,
            cfg: None,
        }),
        // X-LoRA: a normal loader with the X-LoRA ordering read from `order`.
        ModelSelected::XLora {
            model_id,
            xlora_model_id,
            order,
            tokenizer_json,
            tgt_non_granular_index,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(arch)?,
        // LoRA: `adapter_model_id` may list multiple adapters separated by
        // MULTI_LORA_DELIMITER.
        ModelSelected::Lora {
            model_id,
            tokenizer_json,
            adapter_model_id,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
                matformer_config_path: None,
                matformer_slice_name: None,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapter_model_id
                .split(MULTI_LORA_DELIMITER)
                .map(ToString::to_string)
                .collect(),
        )
        .build(arch)?,
        // GGUF: `quantized_filename` may list multiple shards separated by
        // GGUF_MULTI_FILE_DELIMITER.
        ModelSelected::GGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        ModelSelected::GGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
        ModelSelected::MultiModel { .. } => {
            anyhow::bail!("MultiModel variant should not be used in model loading functions")
        }
    };
    Ok(loader)
}