// mistralrs_core/model_loader.rs

use std::{
    fs::{self, File},
    num::NonZeroUsize,
};

use anyhow::Context;
use mistralrs_quant::MULTI_LORA_DELIMITER;

use crate::{
    get_toml_selected_model_dtype,
    pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig},
    toml_selector::get_toml_selected_model_device_map_params,
    AutoDeviceMapParams, DiffusionLoaderBuilder, DiffusionSpecificConfig, GGUFSpecificConfig,
    Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector, Topology,
    VisionLoaderBuilder, VisionSpecificConfig, GGUF_MULTI_FILE_DELIMITER,
};
16
/// A builder for a loader using the selected model.
pub struct LoaderBuilder {
    /// Which model/backend was selected (plain, LoRA, X-LoRA, GGUF, GGML,
    /// vision, diffusion, or a TOML selector file).
    model: ModelSelected,
    /// When `true`, the loader is built without a KV cache (forwarded as
    /// `no_kv_cache` to the underlying loader builders).
    no_kv_cache: bool,
    /// Optional chat template override, forwarded to the loader builders.
    chat_template: Option<String>,
    /// Optional explicit JINJA chat template, forwarded to the loader builders.
    jinja_explicit: Option<String>,
    /// Whether flash attention is requested (forwarded via the loaders'
    /// specific configs).
    use_flash_attn: bool,
    /// Optional prompt chunk size forwarded to the loaders' specific configs;
    /// `None` keeps the loader default.
    prompt_chunksize: Option<NonZeroUsize>,
}
26
impl LoaderBuilder {
    /// Create a builder for the given model selection with all options at
    /// their defaults: KV cache enabled, no chat-template or JINJA override,
    /// flash attention off, no prompt chunk size.
    pub fn new(model: ModelSelected) -> Self {
        Self {
            model,
            no_kv_cache: false,
            chat_template: None,
            use_flash_attn: false,
            prompt_chunksize: None,
            jinja_explicit: None,
        }
    }

    /// Set whether the KV cache should be disabled.
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }
    /// Set an optional chat template override.
    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
        self.chat_template = chat_template;
        self
    }
    /// Set an optional explicit JINJA chat template.
    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
        self.jinja_explicit = jinja_explicit;
        self
    }
    /// Set whether flash attention should be used.
    pub fn with_use_flash_attn(mut self, use_flash_attn: bool) -> Self {
        self.use_flash_attn = use_flash_attn;
        self
    }
    /// Set the optional prompt chunk size.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: Option<NonZeroUsize>) -> Self {
        self.prompt_chunksize = prompt_chunksize;
        self
    }

    /// Consume the builder and construct the [`Loader`] for the selection.
    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
        loader_from_model_selected(self)
    }
}
64
65pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
66    match model {
67        ModelSelected::Plain { .. }
68        | ModelSelected::Lora { .. }
69        | ModelSelected::GGUF { .. }
70        | ModelSelected::LoraGGUF { .. }
71        | ModelSelected::GGML { .. }
72        | ModelSelected::LoraGGML { .. }
73        | ModelSelected::Toml { .. }
74        | ModelSelected::VisionPlain { .. }
75        | ModelSelected::DiffusionPlain { .. } => None,
76        ModelSelected::XLora {
77            tgt_non_granular_index,
78            ..
79        }
80        | ModelSelected::XLoraGGUF {
81            tgt_non_granular_index,
82            ..
83        }
84        | ModelSelected::XLoraGGML {
85            tgt_non_granular_index,
86            ..
87        } => *tgt_non_granular_index,
88    }
89}
90
91pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
92    match model {
93        ModelSelected::Plain { dtype, .. }
94        | ModelSelected::Lora { dtype, .. }
95        | ModelSelected::XLora { dtype, .. }
96        | ModelSelected::VisionPlain { dtype, .. }
97        | ModelSelected::DiffusionPlain { dtype, .. }
98        | ModelSelected::GGML { dtype, .. }
99        | ModelSelected::GGUF { dtype, .. }
100        | ModelSelected::XLoraGGUF { dtype, .. }
101        | ModelSelected::XLoraGGML { dtype, .. }
102        | ModelSelected::LoraGGUF { dtype, .. }
103        | ModelSelected::LoraGGML { dtype, .. } => Ok(*dtype),
104        ModelSelected::Toml { file } => {
105            let selector: TomlSelector = toml::from_str(
106                &fs::read_to_string(file.clone())
107                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
108            )?;
109            Ok(get_toml_selected_model_dtype(&selector))
110        }
111    }
112}
113
114pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
115    match model {
116        ModelSelected::Plain {
117            max_seq_len,
118            max_batch_size,
119            ..
120        }
121        | ModelSelected::Lora {
122            max_seq_len,
123            max_batch_size,
124            ..
125        }
126        | ModelSelected::XLora {
127            max_seq_len,
128            max_batch_size,
129            ..
130        }
131        | ModelSelected::GGML {
132            max_seq_len,
133            max_batch_size,
134            ..
135        }
136        | ModelSelected::GGUF {
137            max_seq_len,
138            max_batch_size,
139            ..
140        }
141        | ModelSelected::XLoraGGUF {
142            max_seq_len,
143            max_batch_size,
144            ..
145        }
146        | ModelSelected::XLoraGGML {
147            max_seq_len,
148            max_batch_size,
149            ..
150        }
151        | ModelSelected::LoraGGUF {
152            max_seq_len,
153            max_batch_size,
154            ..
155        }
156        | ModelSelected::LoraGGML {
157            max_seq_len,
158            max_batch_size,
159            ..
160        } => Ok(AutoDeviceMapParams::Text {
161            max_seq_len: *max_seq_len,
162            max_batch_size: *max_batch_size,
163        }),
164        ModelSelected::VisionPlain {
165            max_seq_len,
166            max_batch_size,
167            max_image_length,
168            max_num_images,
169            ..
170        } => Ok(AutoDeviceMapParams::Vision {
171            max_seq_len: *max_seq_len,
172            max_batch_size: *max_batch_size,
173            max_image_shape: (*max_image_length, *max_image_length),
174            max_num_images: *max_num_images,
175        }),
176        ModelSelected::DiffusionPlain { .. } => Ok(AutoDeviceMapParams::default_text()),
177        ModelSelected::Toml { file } => {
178            let selector: TomlSelector = toml::from_str(
179                &fs::read_to_string(file.clone())
180                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
181            )?;
182            get_toml_selected_model_device_map_params(&selector)
183        }
184    }
185}
186
187fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
188    let use_flash_attn = args.use_flash_attn;
189    let loader: Box<dyn Loader> = match args.model {
190        ModelSelected::Toml { file } => {
191            let selector: TomlSelector = toml::from_str(
192                &fs::read_to_string(file.clone())
193                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
194            )?;
195            let args = TomlLoaderArgs {
196                use_flash_attn,
197                chat_template: args.chat_template,
198                no_kv_cache: args.no_kv_cache,
199                prompt_chunksize: args.prompt_chunksize,
200                jinja_explicit: args.jinja_explicit,
201            };
202            (selector, args).try_into()?
203        }
204        ModelSelected::Plain {
205            model_id,
206            tokenizer_json,
207            arch,
208            dtype: _,
209            topology,
210            organization,
211            write_uqff,
212            from_uqff,
213            imatrix,
214            calibration_file,
215            max_seq_len: _,
216            max_batch_size: _,
217            hf_cache_path,
218        } => NormalLoaderBuilder::new(
219            NormalSpecificConfig {
220                use_flash_attn,
221                prompt_chunksize: args.prompt_chunksize,
222                topology: Topology::from_option_path(topology)?,
223                organization: organization.unwrap_or_default(),
224                write_uqff,
225                from_uqff,
226                imatrix,
227                calibration_file,
228                hf_cache_path,
229            },
230            args.chat_template,
231            tokenizer_json,
232            Some(model_id),
233            args.no_kv_cache,
234            args.jinja_explicit,
235        )
236        .build(arch)?,
237        ModelSelected::XLora {
238            model_id,
239            xlora_model_id,
240            order,
241            tokenizer_json,
242            tgt_non_granular_index,
243            arch,
244            dtype: _,
245            topology,
246            write_uqff,
247            from_uqff,
248            max_seq_len: _,
249            max_batch_size: _,
250            hf_cache_path,
251        } => NormalLoaderBuilder::new(
252            NormalSpecificConfig {
253                use_flash_attn,
254                prompt_chunksize: args.prompt_chunksize,
255                topology: Topology::from_option_path(topology)?,
256                organization: Default::default(),
257                write_uqff,
258                from_uqff,
259                imatrix: None,
260                calibration_file: None,
261                hf_cache_path,
262            },
263            args.chat_template,
264            tokenizer_json,
265            model_id,
266            args.no_kv_cache,
267            args.jinja_explicit,
268        )
269        .with_xlora(
270            xlora_model_id,
271            serde_json::from_reader(
272                File::open(order.clone())
273                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
274            )?,
275            args.no_kv_cache,
276            tgt_non_granular_index,
277        )
278        .build(arch)?,
279        ModelSelected::Lora {
280            model_id,
281            tokenizer_json,
282            adapter_model_id,
283            arch,
284            dtype: _,
285            topology,
286            write_uqff,
287            from_uqff,
288            max_seq_len: _,
289            max_batch_size: _,
290            hf_cache_path,
291        } => NormalLoaderBuilder::new(
292            NormalSpecificConfig {
293                use_flash_attn,
294                prompt_chunksize: args.prompt_chunksize,
295                topology: Topology::from_option_path(topology)?,
296                organization: Default::default(),
297                write_uqff,
298                from_uqff,
299                imatrix: None,
300                calibration_file: None,
301                hf_cache_path,
302            },
303            args.chat_template,
304            tokenizer_json,
305            model_id,
306            args.no_kv_cache,
307            args.jinja_explicit,
308        )
309        .with_lora(
310            adapter_model_id
311                .split(MULTI_LORA_DELIMITER)
312                .map(ToString::to_string)
313                .collect(),
314        )
315        .build(arch)?,
316        ModelSelected::GGUF {
317            tok_model_id,
318            quantized_model_id,
319            quantized_filename,
320            topology,
321            ..
322        } => GGUFLoaderBuilder::new(
323            args.chat_template,
324            tok_model_id,
325            quantized_model_id,
326            quantized_filename
327                .split(GGUF_MULTI_FILE_DELIMITER)
328                .map(ToOwned::to_owned)
329                .collect::<Vec<_>>(),
330            GGUFSpecificConfig {
331                prompt_chunksize: args.prompt_chunksize,
332                topology: Topology::from_option_path(topology)?,
333            },
334            args.no_kv_cache,
335            args.jinja_explicit,
336        )
337        .build(),
338        ModelSelected::XLoraGGUF {
339            tok_model_id,
340            quantized_model_id,
341            quantized_filename,
342            xlora_model_id,
343            order,
344            tgt_non_granular_index,
345            topology,
346            ..
347        } => GGUFLoaderBuilder::new(
348            args.chat_template,
349            tok_model_id,
350            quantized_model_id,
351            quantized_filename
352                .split(GGUF_MULTI_FILE_DELIMITER)
353                .map(ToOwned::to_owned)
354                .collect::<Vec<_>>(),
355            GGUFSpecificConfig {
356                prompt_chunksize: args.prompt_chunksize,
357                topology: Topology::from_option_path(topology)?,
358            },
359            args.no_kv_cache,
360            args.jinja_explicit,
361        )
362        .with_xlora(
363            xlora_model_id,
364            serde_json::from_reader(
365                File::open(order.clone())
366                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
367            )?,
368            args.no_kv_cache,
369            tgt_non_granular_index,
370        )
371        .build(),
372        ModelSelected::LoraGGUF {
373            tok_model_id,
374            quantized_model_id,
375            quantized_filename,
376            adapters_model_id,
377            order,
378            topology,
379            ..
380        } => GGUFLoaderBuilder::new(
381            args.chat_template,
382            tok_model_id,
383            quantized_model_id,
384            quantized_filename
385                .split(GGUF_MULTI_FILE_DELIMITER)
386                .map(ToOwned::to_owned)
387                .collect::<Vec<_>>(),
388            GGUFSpecificConfig {
389                prompt_chunksize: args.prompt_chunksize,
390                topology: Topology::from_option_path(topology)?,
391            },
392            args.no_kv_cache,
393            args.jinja_explicit,
394        )
395        .with_lora(
396            adapters_model_id,
397            serde_json::from_reader(
398                File::open(order.clone())
399                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
400            )?,
401        )
402        .build(),
403        ModelSelected::GGML {
404            tok_model_id,
405            tokenizer_json,
406            quantized_model_id,
407            quantized_filename,
408            gqa,
409            topology,
410            ..
411        } => GGMLLoaderBuilder::new(
412            GGMLSpecificConfig {
413                gqa,
414                prompt_chunksize: args.prompt_chunksize,
415                topology: Topology::from_option_path(topology)?,
416            },
417            args.chat_template,
418            tokenizer_json,
419            Some(tok_model_id),
420            quantized_model_id,
421            quantized_filename,
422            args.no_kv_cache,
423            args.jinja_explicit,
424        )
425        .build(),
426        ModelSelected::XLoraGGML {
427            tok_model_id,
428            tokenizer_json,
429            quantized_model_id,
430            quantized_filename,
431            xlora_model_id,
432            order,
433            tgt_non_granular_index,
434            gqa,
435            topology,
436            ..
437        } => GGMLLoaderBuilder::new(
438            GGMLSpecificConfig {
439                gqa,
440                prompt_chunksize: args.prompt_chunksize,
441                topology: Topology::from_option_path(topology)?,
442            },
443            args.chat_template,
444            tokenizer_json,
445            tok_model_id,
446            quantized_model_id,
447            quantized_filename,
448            args.no_kv_cache,
449            args.jinja_explicit,
450        )
451        .with_xlora(
452            xlora_model_id,
453            serde_json::from_reader(
454                File::open(order.clone())
455                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
456            )?,
457            args.no_kv_cache,
458            tgt_non_granular_index,
459        )
460        .build(),
461        ModelSelected::LoraGGML {
462            tok_model_id,
463            tokenizer_json,
464            quantized_model_id,
465            quantized_filename,
466            adapters_model_id,
467            order,
468            gqa,
469            topology,
470            ..
471        } => GGMLLoaderBuilder::new(
472            GGMLSpecificConfig {
473                gqa,
474                prompt_chunksize: args.prompt_chunksize,
475                topology: Topology::from_option_path(topology)?,
476            },
477            args.chat_template,
478            tokenizer_json,
479            tok_model_id,
480            quantized_model_id,
481            quantized_filename,
482            args.no_kv_cache,
483            args.jinja_explicit,
484        )
485        .with_lora(
486            adapters_model_id,
487            serde_json::from_reader(
488                File::open(order.clone())
489                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
490            )?,
491        )
492        .build(),
493        ModelSelected::VisionPlain {
494            model_id,
495            tokenizer_json,
496            arch,
497            dtype: _,
498            topology,
499            write_uqff,
500            from_uqff,
501            max_edge,
502            calibration_file,
503            max_seq_len: _,
504            max_batch_size: _,
505            max_num_images: _,
506            max_image_length: _,
507            hf_cache_path,
508            imatrix,
509        } => VisionLoaderBuilder::new(
510            VisionSpecificConfig {
511                use_flash_attn,
512                prompt_chunksize: args.prompt_chunksize,
513                topology: Topology::from_option_path(topology)?,
514                write_uqff,
515                from_uqff,
516                max_edge,
517                calibration_file,
518                imatrix,
519                hf_cache_path,
520            },
521            args.chat_template,
522            tokenizer_json,
523            Some(model_id),
524            args.jinja_explicit,
525        )
526        .build(arch),
527        ModelSelected::DiffusionPlain {
528            model_id,
529            arch,
530            dtype: _,
531        } => {
532            DiffusionLoaderBuilder::new(DiffusionSpecificConfig { use_flash_attn }, Some(model_id))
533                .build(arch)
534        }
535    };
536    Ok(loader)
537}