// mistralrs_core/model_loader.rs

1use std::{
2    fs::{self, File},
3    path::PathBuf,
4    str::FromStr,
5};
6
7use mistralrs_quant::MULTI_LORA_DELIMITER;
8
9use crate::{
10    get_toml_selected_model_dtype,
11    pipeline::{
12        AutoLoaderBuilder, DiffusionLoaderBuilder, GGMLLoaderBuilder, GGMLSpecificConfig,
13        GGUFLoaderBuilder, GGUFSpecificConfig, NormalLoaderBuilder, NormalSpecificConfig,
14        VisionLoaderBuilder, VisionSpecificConfig,
15    },
16    toml_selector::get_toml_selected_model_device_map_params,
17    AutoDeviceMapParams, Loader, ModelDType, ModelSelected, SpeechLoader, TomlLoaderArgs,
18    TomlSelector, Topology, GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
19};
20
/// A builder for a loader using the selected model.
pub struct LoaderBuilder {
    // The model selection this builder will construct a loader for.
    model: ModelSelected,
    // Disable the KV cache when true; defaults to false (cache enabled).
    no_kv_cache: bool,
    // Optional chat template, forwarded to the underlying loader builders.
    chat_template: Option<String>,
    // Optional explicit Jinja template, forwarded to the underlying loader builders.
    jinja_explicit: Option<String>,
}
28
29impl LoaderBuilder {
30    pub fn new(model: ModelSelected) -> Self {
31        Self {
32            model,
33            no_kv_cache: false,
34            chat_template: None,
35            jinja_explicit: None,
36        }
37    }
38
39    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
40        self.no_kv_cache = no_kv_cache;
41        self
42    }
43    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
44        self.chat_template = chat_template;
45        self
46    }
47    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
48        self.jinja_explicit = jinja_explicit;
49        self
50    }
51
52    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
53        loader_from_model_selected(self)
54    }
55}
56
57pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
58    match model {
59        ModelSelected::Plain { .. }
60        | ModelSelected::Run { .. }
61        | ModelSelected::Lora { .. }
62        | ModelSelected::GGUF { .. }
63        | ModelSelected::LoraGGUF { .. }
64        | ModelSelected::GGML { .. }
65        | ModelSelected::LoraGGML { .. }
66        | ModelSelected::Toml { .. }
67        | ModelSelected::VisionPlain { .. }
68        | ModelSelected::DiffusionPlain { .. }
69        | ModelSelected::Speech { .. } => None,
70        ModelSelected::XLora {
71            tgt_non_granular_index,
72            ..
73        }
74        | ModelSelected::XLoraGGUF {
75            tgt_non_granular_index,
76            ..
77        }
78        | ModelSelected::XLoraGGML {
79            tgt_non_granular_index,
80            ..
81        } => *tgt_non_granular_index,
82        ModelSelected::MultiModel { .. } => {
83            panic!("MultiModel variant should not be used in model loading functions")
84        }
85    }
86}
87
88pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
89    match model {
90        ModelSelected::Plain { dtype, .. }
91        | ModelSelected::Lora { dtype, .. }
92        | ModelSelected::XLora { dtype, .. }
93        | ModelSelected::VisionPlain { dtype, .. }
94        | ModelSelected::DiffusionPlain { dtype, .. }
95        | ModelSelected::GGML { dtype, .. }
96        | ModelSelected::GGUF { dtype, .. }
97        | ModelSelected::XLoraGGUF { dtype, .. }
98        | ModelSelected::XLoraGGML { dtype, .. }
99        | ModelSelected::LoraGGUF { dtype, .. }
100        | ModelSelected::LoraGGML { dtype, .. }
101        | ModelSelected::Run { dtype, .. }
102        | ModelSelected::Speech { dtype, .. } => Ok(*dtype),
103        ModelSelected::Toml { file } => {
104            let selector: TomlSelector = toml::from_str(
105                &fs::read_to_string(file.clone())
106                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
107            )?;
108            Ok(get_toml_selected_model_dtype(&selector))
109        }
110        ModelSelected::MultiModel { .. } => {
111            anyhow::bail!("MultiModel variant should not be used in model loading functions")
112        }
113    }
114}
115
116pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
117    match model {
118        ModelSelected::Plain {
119            max_seq_len,
120            max_batch_size,
121            ..
122        }
123        | ModelSelected::Lora {
124            max_seq_len,
125            max_batch_size,
126            ..
127        }
128        | ModelSelected::XLora {
129            max_seq_len,
130            max_batch_size,
131            ..
132        }
133        | ModelSelected::GGML {
134            max_seq_len,
135            max_batch_size,
136            ..
137        }
138        | ModelSelected::GGUF {
139            max_seq_len,
140            max_batch_size,
141            ..
142        }
143        | ModelSelected::XLoraGGUF {
144            max_seq_len,
145            max_batch_size,
146            ..
147        }
148        | ModelSelected::XLoraGGML {
149            max_seq_len,
150            max_batch_size,
151            ..
152        }
153        | ModelSelected::LoraGGUF {
154            max_seq_len,
155            max_batch_size,
156            ..
157        }
158        | ModelSelected::LoraGGML {
159            max_seq_len,
160            max_batch_size,
161            ..
162        } => Ok(AutoDeviceMapParams::Text {
163            max_seq_len: *max_seq_len,
164            max_batch_size: *max_batch_size,
165        }),
166        ModelSelected::Run {
167            max_seq_len,
168            max_batch_size,
169            max_image_length,
170            max_num_images,
171            ..
172        } => {
173            if max_num_images.is_some() || max_image_length.is_some() {
174                let max_image_length =
175                    max_image_length.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH);
176                Ok(AutoDeviceMapParams::Vision {
177                    max_seq_len: *max_seq_len,
178                    max_batch_size: *max_batch_size,
179                    max_image_shape: (max_image_length, max_image_length),
180                    max_num_images: max_num_images
181                        .unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES),
182                })
183            } else {
184                Ok(AutoDeviceMapParams::Text {
185                    max_seq_len: *max_seq_len,
186                    max_batch_size: *max_batch_size,
187                })
188            }
189        }
190        ModelSelected::VisionPlain {
191            max_seq_len,
192            max_batch_size,
193            max_image_length,
194            max_num_images,
195            ..
196        } => Ok(AutoDeviceMapParams::Vision {
197            max_seq_len: *max_seq_len,
198            max_batch_size: *max_batch_size,
199            max_image_shape: (*max_image_length, *max_image_length),
200            max_num_images: *max_num_images,
201        }),
202        ModelSelected::DiffusionPlain { .. } | ModelSelected::Speech { .. } => {
203            Ok(AutoDeviceMapParams::default_text())
204        }
205        ModelSelected::Toml { file } => {
206            let selector: TomlSelector = toml::from_str(
207                &fs::read_to_string(file.clone())
208                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
209            )?;
210            get_toml_selected_model_device_map_params(&selector)
211        }
212        ModelSelected::MultiModel { .. } => {
213            anyhow::bail!("MultiModel variant should not be used in model loading functions")
214        }
215    }
216}
217
218fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
219    let loader: Box<dyn Loader> = match args.model {
220        ModelSelected::Toml { file } => {
221            let selector: TomlSelector = toml::from_str(
222                &fs::read_to_string(file.clone())
223                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
224            )?;
225            let args = TomlLoaderArgs {
226                chat_template: args.chat_template,
227                no_kv_cache: args.no_kv_cache,
228                jinja_explicit: args.jinja_explicit,
229            };
230            (selector, args).try_into()?
231        }
232        ModelSelected::Plain {
233            model_id,
234            tokenizer_json,
235            arch,
236            dtype: _,
237            topology,
238            organization,
239            write_uqff,
240            from_uqff,
241            imatrix,
242            calibration_file,
243            max_seq_len: _,
244            max_batch_size: _,
245            hf_cache_path,
246            matformer_config_path,
247            matformer_slice_name,
248        } => NormalLoaderBuilder::new(
249            NormalSpecificConfig {
250                topology: Topology::from_option_path(topology)?,
251                organization: organization.unwrap_or_default(),
252                write_uqff,
253                from_uqff: from_uqff.map(|x| {
254                    x.split(UQFF_MULTI_FILE_DELIMITER)
255                        .map(PathBuf::from_str)
256                        .map(|x| x.unwrap())
257                        .collect::<Vec<_>>()
258                }),
259                imatrix,
260                calibration_file,
261                hf_cache_path,
262                matformer_config_path,
263                matformer_slice_name,
264            },
265            args.chat_template,
266            tokenizer_json,
267            Some(model_id),
268            args.no_kv_cache,
269            args.jinja_explicit,
270        )
271        .build(arch)?,
272        ModelSelected::Run {
273            model_id,
274            tokenizer_json,
275            dtype: _,
276            topology,
277            organization,
278            write_uqff,
279            from_uqff,
280            imatrix,
281            calibration_file,
282            max_edge,
283            max_seq_len: _,
284            max_batch_size: _,
285            max_num_images: _,
286            max_image_length: _,
287            hf_cache_path,
288            matformer_config_path,
289            matformer_slice_name,
290        } => {
291            let builder = AutoLoaderBuilder::new(
292                NormalSpecificConfig {
293                    topology: Topology::from_option_path(topology.clone())?,
294                    organization: organization.unwrap_or_default(),
295                    write_uqff: write_uqff.clone(),
296                    from_uqff: from_uqff.clone().map(|x| {
297                        x.split(UQFF_MULTI_FILE_DELIMITER)
298                            .map(PathBuf::from_str)
299                            .map(|x| x.unwrap())
300                            .collect::<Vec<_>>()
301                    }),
302                    imatrix: imatrix.clone(),
303                    calibration_file: calibration_file.clone(),
304                    hf_cache_path: hf_cache_path.clone(),
305                    matformer_config_path: matformer_config_path.clone(),
306                    matformer_slice_name: matformer_slice_name.clone(),
307                },
308                VisionSpecificConfig {
309                    topology: Topology::from_option_path(topology)?,
310                    write_uqff,
311                    from_uqff: from_uqff.map(|x| {
312                        x.split(UQFF_MULTI_FILE_DELIMITER)
313                            .map(PathBuf::from_str)
314                            .map(|x| x.unwrap())
315                            .collect::<Vec<_>>()
316                    }),
317                    max_edge,
318                    calibration_file,
319                    imatrix,
320                    hf_cache_path: hf_cache_path.clone(),
321                    matformer_config_path,
322                    matformer_slice_name,
323                },
324                args.chat_template,
325                tokenizer_json,
326                model_id,
327                args.no_kv_cache,
328                args.jinja_explicit,
329            );
330            let builder = if let Some(ref path) = hf_cache_path {
331                builder.hf_cache_path(path.clone())
332            } else {
333                builder
334            };
335            builder.build()
336        }
337        ModelSelected::VisionPlain {
338            model_id,
339            tokenizer_json,
340            arch,
341            dtype: _,
342            topology,
343            write_uqff,
344            from_uqff,
345            max_edge,
346            calibration_file,
347            max_seq_len: _,
348            max_batch_size: _,
349            max_num_images: _,
350            max_image_length: _,
351            hf_cache_path,
352            imatrix,
353            matformer_config_path,
354            matformer_slice_name,
355        } => VisionLoaderBuilder::new(
356            VisionSpecificConfig {
357                topology: Topology::from_option_path(topology)?,
358                write_uqff,
359                from_uqff: from_uqff.map(|x| {
360                    x.split(UQFF_MULTI_FILE_DELIMITER)
361                        .map(PathBuf::from_str)
362                        .map(|x| x.unwrap())
363                        .collect::<Vec<_>>()
364                }),
365                max_edge,
366                calibration_file,
367                imatrix,
368                hf_cache_path,
369                matformer_config_path,
370                matformer_slice_name,
371            },
372            args.chat_template,
373            tokenizer_json,
374            Some(model_id),
375            args.jinja_explicit,
376        )
377        .build(arch),
378        ModelSelected::DiffusionPlain {
379            model_id,
380            arch,
381            dtype: _,
382        } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
383        ModelSelected::Speech {
384            model_id,
385            dac_model_id,
386            arch,
387            ..
388        } => Box::new(SpeechLoader {
389            model_id,
390            dac_model_id,
391            arch,
392            cfg: None,
393        }),
394        ModelSelected::XLora {
395            model_id,
396            xlora_model_id,
397            order,
398            tokenizer_json,
399            tgt_non_granular_index,
400            arch,
401            dtype: _,
402            topology,
403            write_uqff,
404            from_uqff,
405            max_seq_len: _,
406            max_batch_size: _,
407            hf_cache_path,
408        } => NormalLoaderBuilder::new(
409            NormalSpecificConfig {
410                topology: Topology::from_option_path(topology)?,
411                organization: Default::default(),
412                write_uqff,
413                from_uqff: from_uqff.map(|x| {
414                    x.split(UQFF_MULTI_FILE_DELIMITER)
415                        .map(PathBuf::from_str)
416                        .map(|x| x.unwrap())
417                        .collect::<Vec<_>>()
418                }),
419                imatrix: None,
420                calibration_file: None,
421                hf_cache_path,
422                matformer_config_path: None,
423                matformer_slice_name: None,
424            },
425            args.chat_template,
426            tokenizer_json,
427            model_id,
428            args.no_kv_cache,
429            args.jinja_explicit,
430        )
431        .with_xlora(
432            xlora_model_id,
433            serde_json::from_reader(
434                File::open(order.clone())
435                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
436            )?,
437            args.no_kv_cache,
438            tgt_non_granular_index,
439        )
440        .build(arch)?,
441        ModelSelected::Lora {
442            model_id,
443            tokenizer_json,
444            adapter_model_id,
445            arch,
446            dtype: _,
447            topology,
448            write_uqff,
449            from_uqff,
450            max_seq_len: _,
451            max_batch_size: _,
452            hf_cache_path,
453        } => NormalLoaderBuilder::new(
454            NormalSpecificConfig {
455                topology: Topology::from_option_path(topology)?,
456                organization: Default::default(),
457                write_uqff,
458                from_uqff: from_uqff.map(|x| {
459                    x.split(UQFF_MULTI_FILE_DELIMITER)
460                        .map(PathBuf::from_str)
461                        .map(|x| x.unwrap())
462                        .collect::<Vec<_>>()
463                }),
464                imatrix: None,
465                calibration_file: None,
466                hf_cache_path,
467                matformer_config_path: None,
468                matformer_slice_name: None,
469            },
470            args.chat_template,
471            tokenizer_json,
472            model_id,
473            args.no_kv_cache,
474            args.jinja_explicit,
475        )
476        .with_lora(
477            adapter_model_id
478                .split(MULTI_LORA_DELIMITER)
479                .map(ToString::to_string)
480                .collect(),
481        )
482        .build(arch)?,
483        ModelSelected::GGUF {
484            tok_model_id,
485            quantized_model_id,
486            quantized_filename,
487            topology,
488            ..
489        } => GGUFLoaderBuilder::new(
490            args.chat_template,
491            tok_model_id,
492            quantized_model_id,
493            quantized_filename
494                .split(GGUF_MULTI_FILE_DELIMITER)
495                .map(ToOwned::to_owned)
496                .collect::<Vec<_>>(),
497            GGUFSpecificConfig {
498                topology: Topology::from_option_path(topology)?,
499            },
500            args.no_kv_cache,
501            args.jinja_explicit,
502        )
503        .build(),
504        ModelSelected::XLoraGGUF {
505            tok_model_id,
506            quantized_model_id,
507            quantized_filename,
508            xlora_model_id,
509            order,
510            tgt_non_granular_index,
511            topology,
512            ..
513        } => GGUFLoaderBuilder::new(
514            args.chat_template,
515            tok_model_id,
516            quantized_model_id,
517            quantized_filename
518                .split(GGUF_MULTI_FILE_DELIMITER)
519                .map(ToOwned::to_owned)
520                .collect::<Vec<_>>(),
521            GGUFSpecificConfig {
522                topology: Topology::from_option_path(topology)?,
523            },
524            args.no_kv_cache,
525            args.jinja_explicit,
526        )
527        .with_xlora(
528            xlora_model_id,
529            serde_json::from_reader(
530                File::open(order.clone())
531                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
532            )?,
533            args.no_kv_cache,
534            tgt_non_granular_index,
535        )
536        .build(),
537        ModelSelected::LoraGGUF {
538            tok_model_id,
539            quantized_model_id,
540            quantized_filename,
541            adapters_model_id,
542            order,
543            topology,
544            ..
545        } => GGUFLoaderBuilder::new(
546            args.chat_template,
547            tok_model_id,
548            quantized_model_id,
549            quantized_filename
550                .split(GGUF_MULTI_FILE_DELIMITER)
551                .map(ToOwned::to_owned)
552                .collect::<Vec<_>>(),
553            GGUFSpecificConfig {
554                topology: Topology::from_option_path(topology)?,
555            },
556            args.no_kv_cache,
557            args.jinja_explicit,
558        )
559        .with_lora(
560            adapters_model_id,
561            serde_json::from_reader(
562                File::open(order.clone())
563                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
564            )?,
565        )
566        .build(),
567        ModelSelected::GGML {
568            tok_model_id,
569            tokenizer_json,
570            quantized_model_id,
571            quantized_filename,
572            gqa,
573            topology,
574            ..
575        } => GGMLLoaderBuilder::new(
576            GGMLSpecificConfig {
577                gqa,
578                topology: Topology::from_option_path(topology)?,
579            },
580            args.chat_template,
581            tokenizer_json,
582            Some(tok_model_id),
583            quantized_model_id,
584            quantized_filename,
585            args.no_kv_cache,
586            args.jinja_explicit,
587        )
588        .build(),
589        ModelSelected::XLoraGGML {
590            tok_model_id,
591            tokenizer_json,
592            quantized_model_id,
593            quantized_filename,
594            xlora_model_id,
595            order,
596            tgt_non_granular_index,
597            gqa,
598            topology,
599            ..
600        } => GGMLLoaderBuilder::new(
601            GGMLSpecificConfig {
602                gqa,
603                topology: Topology::from_option_path(topology)?,
604            },
605            args.chat_template,
606            tokenizer_json,
607            tok_model_id,
608            quantized_model_id,
609            quantized_filename,
610            args.no_kv_cache,
611            args.jinja_explicit,
612        )
613        .with_xlora(
614            xlora_model_id,
615            serde_json::from_reader(
616                File::open(order.clone())
617                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
618            )?,
619            args.no_kv_cache,
620            tgt_non_granular_index,
621        )
622        .build(),
623        ModelSelected::LoraGGML {
624            tok_model_id,
625            tokenizer_json,
626            quantized_model_id,
627            quantized_filename,
628            adapters_model_id,
629            order,
630            gqa,
631            topology,
632            ..
633        } => GGMLLoaderBuilder::new(
634            GGMLSpecificConfig {
635                gqa,
636                topology: Topology::from_option_path(topology)?,
637            },
638            args.chat_template,
639            tokenizer_json,
640            tok_model_id,
641            quantized_model_id,
642            quantized_filename,
643            args.no_kv_cache,
644            args.jinja_explicit,
645        )
646        .with_lora(
647            adapters_model_id,
648            serde_json::from_reader(
649                File::open(order.clone())
650                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
651            )?,
652        )
653        .build(),
654        ModelSelected::MultiModel { .. } => {
655            anyhow::bail!("MultiModel variant should not be used in model loading functions")
656        }
657    };
658    Ok(loader)
659}