//! mistralrs_core/model_loader.rs — builds a `Loader` from a `ModelSelected` selection.

1use std::{
2    fs::{self, File},
3    num::NonZeroUsize,
4    path::PathBuf,
5    str::FromStr,
6};
7
8use mistralrs_quant::MULTI_LORA_DELIMITER;
9
10use crate::{
11    get_toml_selected_model_dtype,
12    pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig},
13    toml_selector::get_toml_selected_model_device_map_params,
14    AutoDeviceMapParams, DiffusionLoaderBuilder, DiffusionSpecificConfig, GGUFSpecificConfig,
15    Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector, Topology,
16    VisionLoaderBuilder, VisionSpecificConfig, GGUF_MULTI_FILE_DELIMITER,
17    UQFF_MULTI_FILE_DELIMITER,
18};
19
/// A builder for a loader using the selected model.
pub struct LoaderBuilder {
    // The parsed model selection (CLI subcommand / TOML entry) to build a loader for.
    model: ModelSelected,
    // When true, the KV cache is disabled for this model.
    no_kv_cache: bool,
    // Optional chat template override; `None` uses the model's own template.
    // NOTE(review): whether this is a path or a literal template is decided by
    // the downstream builders — confirm at call sites.
    chat_template: Option<String>,
    // Optional explicit Jinja template override, passed through to the builders.
    jinja_explicit: Option<String>,
    // Whether to enable flash attention in the loader-specific configs.
    use_flash_attn: bool,
    // Optional non-zero chunk size used when processing the prompt; `None`
    // leaves chunking to the pipeline default.
    prompt_chunksize: Option<NonZeroUsize>,
}
29
30impl LoaderBuilder {
31    pub fn new(model: ModelSelected) -> Self {
32        Self {
33            model,
34            no_kv_cache: false,
35            chat_template: None,
36            use_flash_attn: false,
37            prompt_chunksize: None,
38            jinja_explicit: None,
39        }
40    }
41
42    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
43        self.no_kv_cache = no_kv_cache;
44        self
45    }
46    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
47        self.chat_template = chat_template;
48        self
49    }
50    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
51        self.jinja_explicit = jinja_explicit;
52        self
53    }
54    pub fn with_use_flash_attn(mut self, use_flash_attn: bool) -> Self {
55        self.use_flash_attn = use_flash_attn;
56        self
57    }
58    pub fn with_prompt_chunksize(mut self, prompt_chunksize: Option<NonZeroUsize>) -> Self {
59        self.prompt_chunksize = prompt_chunksize;
60        self
61    }
62
63    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
64        loader_from_model_selected(self)
65    }
66}
67
68pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
69    match model {
70        ModelSelected::Plain { .. }
71        | ModelSelected::Lora { .. }
72        | ModelSelected::GGUF { .. }
73        | ModelSelected::LoraGGUF { .. }
74        | ModelSelected::GGML { .. }
75        | ModelSelected::LoraGGML { .. }
76        | ModelSelected::Toml { .. }
77        | ModelSelected::VisionPlain { .. }
78        | ModelSelected::DiffusionPlain { .. } => None,
79        ModelSelected::XLora {
80            tgt_non_granular_index,
81            ..
82        }
83        | ModelSelected::XLoraGGUF {
84            tgt_non_granular_index,
85            ..
86        }
87        | ModelSelected::XLoraGGML {
88            tgt_non_granular_index,
89            ..
90        } => *tgt_non_granular_index,
91    }
92}
93
94pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
95    match model {
96        ModelSelected::Plain { dtype, .. }
97        | ModelSelected::Lora { dtype, .. }
98        | ModelSelected::XLora { dtype, .. }
99        | ModelSelected::VisionPlain { dtype, .. }
100        | ModelSelected::DiffusionPlain { dtype, .. }
101        | ModelSelected::GGML { dtype, .. }
102        | ModelSelected::GGUF { dtype, .. }
103        | ModelSelected::XLoraGGUF { dtype, .. }
104        | ModelSelected::XLoraGGML { dtype, .. }
105        | ModelSelected::LoraGGUF { dtype, .. }
106        | ModelSelected::LoraGGML { dtype, .. } => Ok(*dtype),
107        ModelSelected::Toml { file } => {
108            let selector: TomlSelector = toml::from_str(
109                &fs::read_to_string(file.clone())
110                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
111            )?;
112            Ok(get_toml_selected_model_dtype(&selector))
113        }
114    }
115}
116
117pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
118    match model {
119        ModelSelected::Plain {
120            max_seq_len,
121            max_batch_size,
122            ..
123        }
124        | ModelSelected::Lora {
125            max_seq_len,
126            max_batch_size,
127            ..
128        }
129        | ModelSelected::XLora {
130            max_seq_len,
131            max_batch_size,
132            ..
133        }
134        | ModelSelected::GGML {
135            max_seq_len,
136            max_batch_size,
137            ..
138        }
139        | ModelSelected::GGUF {
140            max_seq_len,
141            max_batch_size,
142            ..
143        }
144        | ModelSelected::XLoraGGUF {
145            max_seq_len,
146            max_batch_size,
147            ..
148        }
149        | ModelSelected::XLoraGGML {
150            max_seq_len,
151            max_batch_size,
152            ..
153        }
154        | ModelSelected::LoraGGUF {
155            max_seq_len,
156            max_batch_size,
157            ..
158        }
159        | ModelSelected::LoraGGML {
160            max_seq_len,
161            max_batch_size,
162            ..
163        } => Ok(AutoDeviceMapParams::Text {
164            max_seq_len: *max_seq_len,
165            max_batch_size: *max_batch_size,
166        }),
167        ModelSelected::VisionPlain {
168            max_seq_len,
169            max_batch_size,
170            max_image_length,
171            max_num_images,
172            ..
173        } => Ok(AutoDeviceMapParams::Vision {
174            max_seq_len: *max_seq_len,
175            max_batch_size: *max_batch_size,
176            max_image_shape: (*max_image_length, *max_image_length),
177            max_num_images: *max_num_images,
178        }),
179        ModelSelected::DiffusionPlain { .. } => Ok(AutoDeviceMapParams::default_text()),
180        ModelSelected::Toml { file } => {
181            let selector: TomlSelector = toml::from_str(
182                &fs::read_to_string(file.clone())
183                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
184            )?;
185            get_toml_selected_model_device_map_params(&selector)
186        }
187    }
188}
189
190fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
191    let use_flash_attn = args.use_flash_attn;
192    let loader: Box<dyn Loader> = match args.model {
193        ModelSelected::Toml { file } => {
194            let selector: TomlSelector = toml::from_str(
195                &fs::read_to_string(file.clone())
196                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
197            )?;
198            let args = TomlLoaderArgs {
199                use_flash_attn,
200                chat_template: args.chat_template,
201                no_kv_cache: args.no_kv_cache,
202                prompt_chunksize: args.prompt_chunksize,
203                jinja_explicit: args.jinja_explicit,
204            };
205            (selector, args).try_into()?
206        }
207        ModelSelected::Plain {
208            model_id,
209            tokenizer_json,
210            arch,
211            dtype: _,
212            topology,
213            organization,
214            write_uqff,
215            from_uqff,
216            imatrix,
217            calibration_file,
218            max_seq_len: _,
219            max_batch_size: _,
220            hf_cache_path,
221        } => NormalLoaderBuilder::new(
222            NormalSpecificConfig {
223                use_flash_attn,
224                prompt_chunksize: args.prompt_chunksize,
225                topology: Topology::from_option_path(topology)?,
226                organization: organization.unwrap_or_default(),
227                write_uqff,
228                from_uqff: from_uqff.map(|x| {
229                    x.split(UQFF_MULTI_FILE_DELIMITER)
230                        .map(PathBuf::from_str)
231                        .map(|x| x.unwrap())
232                        .collect::<Vec<_>>()
233                }),
234                imatrix,
235                calibration_file,
236                hf_cache_path,
237            },
238            args.chat_template,
239            tokenizer_json,
240            Some(model_id),
241            args.no_kv_cache,
242            args.jinja_explicit,
243        )
244        .build(arch)?,
245        ModelSelected::XLora {
246            model_id,
247            xlora_model_id,
248            order,
249            tokenizer_json,
250            tgt_non_granular_index,
251            arch,
252            dtype: _,
253            topology,
254            write_uqff,
255            from_uqff,
256            max_seq_len: _,
257            max_batch_size: _,
258            hf_cache_path,
259        } => NormalLoaderBuilder::new(
260            NormalSpecificConfig {
261                use_flash_attn,
262                prompt_chunksize: args.prompt_chunksize,
263                topology: Topology::from_option_path(topology)?,
264                organization: Default::default(),
265                write_uqff,
266                from_uqff: from_uqff.map(|x| {
267                    x.split(UQFF_MULTI_FILE_DELIMITER)
268                        .map(PathBuf::from_str)
269                        .map(|x| x.unwrap())
270                        .collect::<Vec<_>>()
271                }),
272                imatrix: None,
273                calibration_file: None,
274                hf_cache_path,
275            },
276            args.chat_template,
277            tokenizer_json,
278            model_id,
279            args.no_kv_cache,
280            args.jinja_explicit,
281        )
282        .with_xlora(
283            xlora_model_id,
284            serde_json::from_reader(
285                File::open(order.clone())
286                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
287            )?,
288            args.no_kv_cache,
289            tgt_non_granular_index,
290        )
291        .build(arch)?,
292        ModelSelected::Lora {
293            model_id,
294            tokenizer_json,
295            adapter_model_id,
296            arch,
297            dtype: _,
298            topology,
299            write_uqff,
300            from_uqff,
301            max_seq_len: _,
302            max_batch_size: _,
303            hf_cache_path,
304        } => NormalLoaderBuilder::new(
305            NormalSpecificConfig {
306                use_flash_attn,
307                prompt_chunksize: args.prompt_chunksize,
308                topology: Topology::from_option_path(topology)?,
309                organization: Default::default(),
310                write_uqff,
311                from_uqff: from_uqff.map(|x| {
312                    x.split(UQFF_MULTI_FILE_DELIMITER)
313                        .map(PathBuf::from_str)
314                        .map(|x| x.unwrap())
315                        .collect::<Vec<_>>()
316                }),
317                imatrix: None,
318                calibration_file: None,
319                hf_cache_path,
320            },
321            args.chat_template,
322            tokenizer_json,
323            model_id,
324            args.no_kv_cache,
325            args.jinja_explicit,
326        )
327        .with_lora(
328            adapter_model_id
329                .split(MULTI_LORA_DELIMITER)
330                .map(ToString::to_string)
331                .collect(),
332        )
333        .build(arch)?,
334        ModelSelected::GGUF {
335            tok_model_id,
336            quantized_model_id,
337            quantized_filename,
338            topology,
339            ..
340        } => GGUFLoaderBuilder::new(
341            args.chat_template,
342            tok_model_id,
343            quantized_model_id,
344            quantized_filename
345                .split(GGUF_MULTI_FILE_DELIMITER)
346                .map(ToOwned::to_owned)
347                .collect::<Vec<_>>(),
348            GGUFSpecificConfig {
349                prompt_chunksize: args.prompt_chunksize,
350                topology: Topology::from_option_path(topology)?,
351            },
352            args.no_kv_cache,
353            args.jinja_explicit,
354        )
355        .build(),
356        ModelSelected::XLoraGGUF {
357            tok_model_id,
358            quantized_model_id,
359            quantized_filename,
360            xlora_model_id,
361            order,
362            tgt_non_granular_index,
363            topology,
364            ..
365        } => GGUFLoaderBuilder::new(
366            args.chat_template,
367            tok_model_id,
368            quantized_model_id,
369            quantized_filename
370                .split(GGUF_MULTI_FILE_DELIMITER)
371                .map(ToOwned::to_owned)
372                .collect::<Vec<_>>(),
373            GGUFSpecificConfig {
374                prompt_chunksize: args.prompt_chunksize,
375                topology: Topology::from_option_path(topology)?,
376            },
377            args.no_kv_cache,
378            args.jinja_explicit,
379        )
380        .with_xlora(
381            xlora_model_id,
382            serde_json::from_reader(
383                File::open(order.clone())
384                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
385            )?,
386            args.no_kv_cache,
387            tgt_non_granular_index,
388        )
389        .build(),
390        ModelSelected::LoraGGUF {
391            tok_model_id,
392            quantized_model_id,
393            quantized_filename,
394            adapters_model_id,
395            order,
396            topology,
397            ..
398        } => GGUFLoaderBuilder::new(
399            args.chat_template,
400            tok_model_id,
401            quantized_model_id,
402            quantized_filename
403                .split(GGUF_MULTI_FILE_DELIMITER)
404                .map(ToOwned::to_owned)
405                .collect::<Vec<_>>(),
406            GGUFSpecificConfig {
407                prompt_chunksize: args.prompt_chunksize,
408                topology: Topology::from_option_path(topology)?,
409            },
410            args.no_kv_cache,
411            args.jinja_explicit,
412        )
413        .with_lora(
414            adapters_model_id,
415            serde_json::from_reader(
416                File::open(order.clone())
417                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
418            )?,
419        )
420        .build(),
421        ModelSelected::GGML {
422            tok_model_id,
423            tokenizer_json,
424            quantized_model_id,
425            quantized_filename,
426            gqa,
427            topology,
428            ..
429        } => GGMLLoaderBuilder::new(
430            GGMLSpecificConfig {
431                gqa,
432                prompt_chunksize: args.prompt_chunksize,
433                topology: Topology::from_option_path(topology)?,
434            },
435            args.chat_template,
436            tokenizer_json,
437            Some(tok_model_id),
438            quantized_model_id,
439            quantized_filename,
440            args.no_kv_cache,
441            args.jinja_explicit,
442        )
443        .build(),
444        ModelSelected::XLoraGGML {
445            tok_model_id,
446            tokenizer_json,
447            quantized_model_id,
448            quantized_filename,
449            xlora_model_id,
450            order,
451            tgt_non_granular_index,
452            gqa,
453            topology,
454            ..
455        } => GGMLLoaderBuilder::new(
456            GGMLSpecificConfig {
457                gqa,
458                prompt_chunksize: args.prompt_chunksize,
459                topology: Topology::from_option_path(topology)?,
460            },
461            args.chat_template,
462            tokenizer_json,
463            tok_model_id,
464            quantized_model_id,
465            quantized_filename,
466            args.no_kv_cache,
467            args.jinja_explicit,
468        )
469        .with_xlora(
470            xlora_model_id,
471            serde_json::from_reader(
472                File::open(order.clone())
473                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
474            )?,
475            args.no_kv_cache,
476            tgt_non_granular_index,
477        )
478        .build(),
479        ModelSelected::LoraGGML {
480            tok_model_id,
481            tokenizer_json,
482            quantized_model_id,
483            quantized_filename,
484            adapters_model_id,
485            order,
486            gqa,
487            topology,
488            ..
489        } => GGMLLoaderBuilder::new(
490            GGMLSpecificConfig {
491                gqa,
492                prompt_chunksize: args.prompt_chunksize,
493                topology: Topology::from_option_path(topology)?,
494            },
495            args.chat_template,
496            tokenizer_json,
497            tok_model_id,
498            quantized_model_id,
499            quantized_filename,
500            args.no_kv_cache,
501            args.jinja_explicit,
502        )
503        .with_lora(
504            adapters_model_id,
505            serde_json::from_reader(
506                File::open(order.clone())
507                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
508            )?,
509        )
510        .build(),
511        ModelSelected::VisionPlain {
512            model_id,
513            tokenizer_json,
514            arch,
515            dtype: _,
516            topology,
517            write_uqff,
518            from_uqff,
519            max_edge,
520            calibration_file,
521            max_seq_len: _,
522            max_batch_size: _,
523            max_num_images: _,
524            max_image_length: _,
525            hf_cache_path,
526            imatrix,
527        } => VisionLoaderBuilder::new(
528            VisionSpecificConfig {
529                use_flash_attn,
530                prompt_chunksize: args.prompt_chunksize,
531                topology: Topology::from_option_path(topology)?,
532                write_uqff,
533                from_uqff: from_uqff.map(|x| {
534                    x.split(UQFF_MULTI_FILE_DELIMITER)
535                        .map(PathBuf::from_str)
536                        .map(|x| x.unwrap())
537                        .collect::<Vec<_>>()
538                }),
539                max_edge,
540                calibration_file,
541                imatrix,
542                hf_cache_path,
543            },
544            args.chat_template,
545            tokenizer_json,
546            Some(model_id),
547            args.jinja_explicit,
548        )
549        .build(arch),
550        ModelSelected::DiffusionPlain {
551            model_id,
552            arch,
553            dtype: _,
554        } => {
555            DiffusionLoaderBuilder::new(DiffusionSpecificConfig { use_flash_attn }, Some(model_id))
556                .build(arch)
557        }
558    };
559    Ok(loader)
560}