// mistralrs_core/model_loader.rs

use std::{
    fs::{self, File},
    num::NonZeroUsize,
    path::PathBuf,
    str::FromStr,
};

use mistralrs_quant::MULTI_LORA_DELIMITER;

use crate::{
    get_toml_selected_model_dtype,
    pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig},
    toml_selector::get_toml_selected_model_device_map_params,
    AutoDeviceMapParams, DiffusionLoaderBuilder, GGUFSpecificConfig, Loader, ModelDType,
    ModelSelected, NormalLoaderBuilder, SpeechLoader, TomlLoaderArgs, TomlSelector, Topology,
    VisionLoaderBuilder, VisionSpecificConfig, GGUF_MULTI_FILE_DELIMITER,
    UQFF_MULTI_FILE_DELIMITER,
};

/// A builder for a loader using the selected model.
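/// The builder collects options shared by every backend (chat template,
/// KV-cache behavior, prompt chunking, explicit Jinja template), and `build`
/// dispatches on the `ModelSelected` variant.
///
/// Illustrative usage (a minimal sketch; constructing the `ModelSelected`
/// value is elided and the option values are hypothetical):
///
/// ```ignore
/// let loader = LoaderBuilder::new(selected)
///     .with_no_kv_cache(true)
///     .with_prompt_chunksize(NonZeroUsize::new(512))
///     .build()?;
/// ```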
pub struct LoaderBuilder {
    model: ModelSelected,
    no_kv_cache: bool,
    chat_template: Option<String>,
    jinja_explicit: Option<String>,
    prompt_chunksize: Option<NonZeroUsize>,
}

impl LoaderBuilder {
    pub fn new(model: ModelSelected) -> Self {
        Self {
            model,
            no_kv_cache: false,
            chat_template: None,
            prompt_chunksize: None,
            jinja_explicit: None,
        }
    }

    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
        self.chat_template = chat_template;
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
        self.jinja_explicit = jinja_explicit;
        self
    }

    pub fn with_prompt_chunksize(mut self, prompt_chunksize: Option<NonZeroUsize>) -> Self {
        self.prompt_chunksize = prompt_chunksize;
        self
    }

    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
        loader_from_model_selected(self)
    }
}

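/// Returns the X-LoRA `tgt_non_granular_index` of the selected model, if any.
/// Only the X-LoRA variants carry this index; every other selection yields
/// `None`.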
pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
    match model {
        ModelSelected::Plain { .. }
        | ModelSelected::Lora { .. }
        | ModelSelected::GGUF { .. }
        | ModelSelected::LoraGGUF { .. }
        | ModelSelected::GGML { .. }
        | ModelSelected::LoraGGML { .. }
        | ModelSelected::Toml { .. }
        | ModelSelected::VisionPlain { .. }
        | ModelSelected::DiffusionPlain { .. }
        | ModelSelected::Speech { .. } => None,
        ModelSelected::XLora {
            tgt_non_granular_index,
            ..
        }
        | ModelSelected::XLoraGGUF {
            tgt_non_granular_index,
            ..
        }
        | ModelSelected::XLoraGGML {
            tgt_non_granular_index,
            ..
        } => *tgt_non_granular_index,
    }
}

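/// Returns the dtype requested for the selected model. For `Toml` selections
/// the dtype is read from the TOML selector file, which must exist and parse.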
pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
    match model {
        ModelSelected::Plain { dtype, .. }
        | ModelSelected::Lora { dtype, .. }
        | ModelSelected::XLora { dtype, .. }
        | ModelSelected::VisionPlain { dtype, .. }
        | ModelSelected::DiffusionPlain { dtype, .. }
        | ModelSelected::GGML { dtype, .. }
        | ModelSelected::GGUF { dtype, .. }
        | ModelSelected::XLoraGGUF { dtype, .. }
        | ModelSelected::XLoraGGML { dtype, .. }
        | ModelSelected::LoraGGUF { dtype, .. }
        | ModelSelected::LoraGGML { dtype, .. }
        | ModelSelected::Speech { dtype, .. } => Ok(*dtype),
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            Ok(get_toml_selected_model_dtype(&selector))
        }
    }
}

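/// Derives the automatic device-mapping parameters for the selected model.
/// Text models are bounded by `max_seq_len` and `max_batch_size`; vision
/// models additionally bound the image shape and image count; diffusion and
/// speech models fall back to the text defaults.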
pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
    match model {
        ModelSelected::Plain {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::Lora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLora {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::GGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::GGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::XLoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::LoraGGUF {
            max_seq_len,
            max_batch_size,
            ..
        }
        | ModelSelected::LoraGGML {
            max_seq_len,
            max_batch_size,
            ..
        } => Ok(AutoDeviceMapParams::Text {
            max_seq_len: *max_seq_len,
            max_batch_size: *max_batch_size,
        }),
        ModelSelected::VisionPlain {
            max_seq_len,
            max_batch_size,
            max_image_length,
            max_num_images,
            ..
        } => Ok(AutoDeviceMapParams::Vision {
            max_seq_len: *max_seq_len,
            max_batch_size: *max_batch_size,
            max_image_shape: (*max_image_length, *max_image_length),
            max_num_images: *max_num_images,
        }),
        ModelSelected::DiffusionPlain { .. } | ModelSelected::Speech { .. } => {
            Ok(AutoDeviceMapParams::default_text())
        }
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            get_toml_selected_model_device_map_params(&selector)
        }
    }
}

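/// Dispatches on the `ModelSelected` variant and constructs the matching
/// loader, threading the shared builder options (chat template, KV-cache
/// flag, prompt chunk size, explicit Jinja template) into each backend.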
fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
    let loader: Box<dyn Loader> = match args.model {
        ModelSelected::Toml { file } => {
            let selector: TomlSelector = toml::from_str(
                &fs::read_to_string(file.clone())
                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
            )?;
            let args = TomlLoaderArgs {
                chat_template: args.chat_template,
                no_kv_cache: args.no_kv_cache,
                prompt_chunksize: args.prompt_chunksize,
                jinja_explicit: args.jinja_explicit,
            };
            (selector, args).try_into()?
        }
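        // Plain (unquantized) text models. `from_uqff` may name several UQFF
        // files joined by `UQFF_MULTI_FILE_DELIMITER`; the `unwrap` is safe
        // because `PathBuf::from_str` is infallible.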
        ModelSelected::Plain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            organization,
            write_uqff,
            from_uqff,
            imatrix,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: organization.unwrap_or_default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix,
                calibration_file,
                hf_cache_path,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(arch)?,
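        // X-LoRA over a plain base model; the adapter ordering file is
        // required and parsed as JSON.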
        ModelSelected::XLora {
            model_id,
            xlora_model_id,
            order,
            tokenizer_json,
            tgt_non_granular_index,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(arch)?,
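        // LoRA over a plain base model; several adapter IDs may be given,
        // separated by `MULTI_LORA_DELIMITER`.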
        ModelSelected::Lora {
            model_id,
            tokenizer_json,
            adapter_model_id,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_seq_len: _,
            max_batch_size: _,
            hf_cache_path,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                organization: Default::default(),
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                imatrix: None,
                calibration_file: None,
                hf_cache_path,
            },
            args.chat_template,
            tokenizer_json,
            model_id,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapter_model_id
                .split(MULTI_LORA_DELIMITER)
                .map(ToString::to_string)
                .collect(),
        )
        .build(arch)?,
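        // GGUF quantized models; `quantized_filename` may name several GGUF
        // shards joined by `GGUF_MULTI_FILE_DELIMITER`.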
        ModelSelected::GGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGUF {
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            topology,
            ..
        } => GGUFLoaderBuilder::new(
            args.chat_template,
            tok_model_id,
            quantized_model_id,
            quantized_filename
                .split(GGUF_MULTI_FILE_DELIMITER)
                .map(ToOwned::to_owned)
                .collect::<Vec<_>>(),
            GGUFSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
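        // GGML quantized models (legacy llama.cpp format); `gqa` supplies the
        // grouped-query attention factor.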
        ModelSelected::GGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            Some(tok_model_id),
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .build(),
        ModelSelected::XLoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            order,
            tgt_non_granular_index,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_xlora(
            xlora_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
            args.no_kv_cache,
            tgt_non_granular_index,
        )
        .build(),
        ModelSelected::LoraGGML {
            tok_model_id,
            tokenizer_json,
            quantized_model_id,
            quantized_filename,
            adapters_model_id,
            order,
            gqa,
            topology,
            ..
        } => GGMLLoaderBuilder::new(
            GGMLSpecificConfig {
                gqa,
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
            tok_model_id,
            quantized_model_id,
            quantized_filename,
            args.no_kv_cache,
            args.jinja_explicit,
        )
        .with_lora(
            adapters_model_id,
            serde_json::from_reader(
                File::open(order.clone())
                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
            )?,
        )
        .build(),
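        // Vision models; UQFF handling mirrors the plain text case, and
        // `max_edge` optionally caps the image edge length.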
        ModelSelected::VisionPlain {
            model_id,
            tokenizer_json,
            arch,
            dtype: _,
            topology,
            write_uqff,
            from_uqff,
            max_edge,
            calibration_file,
            max_seq_len: _,
            max_batch_size: _,
            max_num_images: _,
            max_image_length: _,
            hf_cache_path,
            imatrix,
        } => VisionLoaderBuilder::new(
            VisionSpecificConfig {
                prompt_chunksize: args.prompt_chunksize,
                topology: Topology::from_option_path(topology)?,
                write_uqff,
                from_uqff: from_uqff.map(|x| {
                    x.split(UQFF_MULTI_FILE_DELIMITER)
                        .map(PathBuf::from_str)
                        .map(|x| x.unwrap())
                        .collect::<Vec<_>>()
                }),
                max_edge,
                calibration_file,
                imatrix,
                hf_cache_path,
            },
            args.chat_template,
            tokenizer_json,
            Some(model_id),
            args.jinja_explicit,
        )
        .build(arch),
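        // Diffusion and speech models need only their model IDs and
        // architecture here.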
        ModelSelected::DiffusionPlain {
            model_id,
            arch,
            dtype: _,
        } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
        ModelSelected::Speech {
            model_id,
            dac_model_id,
            arch,
            ..
        } => Box::new(SpeechLoader {
            model_id,
            dac_model_id,
            arch,
            cfg: None,
        }),
    };
    Ok(loader)
}