1use std::{
2    fs::{self, File},
3    path::PathBuf,
4    str::FromStr,
5};
6
7use mistralrs_quant::MULTI_LORA_DELIMITER;
8
9use crate::{
10    get_toml_selected_model_dtype,
11    pipeline::{
12        AutoLoaderBuilder, DiffusionLoaderBuilder, GGMLLoaderBuilder, GGMLSpecificConfig,
13        GGUFLoaderBuilder, GGUFSpecificConfig, NormalLoaderBuilder, NormalSpecificConfig,
14        VisionLoaderBuilder, VisionSpecificConfig,
15    },
16    toml_selector::get_toml_selected_model_device_map_params,
17    AutoDeviceMapParams, EmbeddingLoaderBuilder, EmbeddingSpecificConfig, Loader, ModelDType,
18    ModelSelected, SpeechLoader, TomlLoaderArgs, TomlSelector, Topology, GGUF_MULTI_FILE_DELIMITER,
19    UQFF_MULTI_FILE_DELIMITER,
20};
21
22pub struct LoaderBuilder {
24    model: ModelSelected,
25    no_kv_cache: bool,
26    chat_template: Option<String>,
27    jinja_explicit: Option<String>,
28}
29
30impl LoaderBuilder {
31    pub fn new(model: ModelSelected) -> Self {
32        Self {
33            model,
34            no_kv_cache: false,
35            chat_template: None,
36            jinja_explicit: None,
37        }
38    }
39
40    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
41        self.no_kv_cache = no_kv_cache;
42        self
43    }
44    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
45        self.chat_template = chat_template;
46        self
47    }
48    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
49        self.jinja_explicit = jinja_explicit;
50        self
51    }
52
53    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
54        loader_from_model_selected(self)
55    }
56}
57
58pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
59    match model {
60        ModelSelected::Plain { .. }
61        | ModelSelected::Run { .. }
62        | ModelSelected::Lora { .. }
63        | ModelSelected::GGUF { .. }
64        | ModelSelected::LoraGGUF { .. }
65        | ModelSelected::GGML { .. }
66        | ModelSelected::LoraGGML { .. }
67        | ModelSelected::Toml { .. }
68        | ModelSelected::VisionPlain { .. }
69        | ModelSelected::DiffusionPlain { .. }
70        | ModelSelected::Speech { .. }
71        | ModelSelected::Embedding { .. } => None,
72        ModelSelected::XLora {
73            tgt_non_granular_index,
74            ..
75        }
76        | ModelSelected::XLoraGGUF {
77            tgt_non_granular_index,
78            ..
79        }
80        | ModelSelected::XLoraGGML {
81            tgt_non_granular_index,
82            ..
83        } => *tgt_non_granular_index,
84        ModelSelected::MultiModel { .. } => {
85            panic!("MultiModel variant should not be used in model loading functions")
86        }
87    }
88}
89
90pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
91    match model {
92        ModelSelected::Plain { dtype, .. }
93        | ModelSelected::Lora { dtype, .. }
94        | ModelSelected::XLora { dtype, .. }
95        | ModelSelected::VisionPlain { dtype, .. }
96        | ModelSelected::DiffusionPlain { dtype, .. }
97        | ModelSelected::GGML { dtype, .. }
98        | ModelSelected::GGUF { dtype, .. }
99        | ModelSelected::XLoraGGUF { dtype, .. }
100        | ModelSelected::XLoraGGML { dtype, .. }
101        | ModelSelected::LoraGGUF { dtype, .. }
102        | ModelSelected::LoraGGML { dtype, .. }
103        | ModelSelected::Run { dtype, .. }
104        | ModelSelected::Speech { dtype, .. }
105        | ModelSelected::Embedding { dtype, .. } => Ok(*dtype),
106        ModelSelected::Toml { file } => {
107            let selector: TomlSelector = toml::from_str(
108                &fs::read_to_string(file.clone())
109                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
110            )?;
111            Ok(get_toml_selected_model_dtype(&selector))
112        }
113        ModelSelected::MultiModel { .. } => {
114            anyhow::bail!("MultiModel variant should not be used in model loading functions")
115        }
116    }
117}
118
119pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
120    match model {
121        ModelSelected::Plain {
122            max_seq_len,
123            max_batch_size,
124            ..
125        }
126        | ModelSelected::Lora {
127            max_seq_len,
128            max_batch_size,
129            ..
130        }
131        | ModelSelected::XLora {
132            max_seq_len,
133            max_batch_size,
134            ..
135        }
136        | ModelSelected::GGML {
137            max_seq_len,
138            max_batch_size,
139            ..
140        }
141        | ModelSelected::GGUF {
142            max_seq_len,
143            max_batch_size,
144            ..
145        }
146        | ModelSelected::XLoraGGUF {
147            max_seq_len,
148            max_batch_size,
149            ..
150        }
151        | ModelSelected::XLoraGGML {
152            max_seq_len,
153            max_batch_size,
154            ..
155        }
156        | ModelSelected::LoraGGUF {
157            max_seq_len,
158            max_batch_size,
159            ..
160        }
161        | ModelSelected::LoraGGML {
162            max_seq_len,
163            max_batch_size,
164            ..
165        } => Ok(AutoDeviceMapParams::Text {
166            max_seq_len: *max_seq_len,
167            max_batch_size: *max_batch_size,
168        }),
169        ModelSelected::Run {
170            max_seq_len,
171            max_batch_size,
172            max_image_length,
173            max_num_images,
174            ..
175        } => {
176            if max_num_images.is_some() || max_image_length.is_some() {
177                let max_image_length =
178                    max_image_length.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH);
179                Ok(AutoDeviceMapParams::Vision {
180                    max_seq_len: *max_seq_len,
181                    max_batch_size: *max_batch_size,
182                    max_image_shape: (max_image_length, max_image_length),
183                    max_num_images: max_num_images
184                        .unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES),
185                })
186            } else {
187                Ok(AutoDeviceMapParams::Text {
188                    max_seq_len: *max_seq_len,
189                    max_batch_size: *max_batch_size,
190                })
191            }
192        }
193        ModelSelected::VisionPlain {
194            max_seq_len,
195            max_batch_size,
196            max_image_length,
197            max_num_images,
198            ..
199        } => Ok(AutoDeviceMapParams::Vision {
200            max_seq_len: *max_seq_len,
201            max_batch_size: *max_batch_size,
202            max_image_shape: (*max_image_length, *max_image_length),
203            max_num_images: *max_num_images,
204        }),
205        ModelSelected::DiffusionPlain { .. }
206        | ModelSelected::Speech { .. }
207        | ModelSelected::Embedding { .. } => Ok(AutoDeviceMapParams::default_text()),
208        ModelSelected::Toml { file } => {
209            let selector: TomlSelector = toml::from_str(
210                &fs::read_to_string(file.clone())
211                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
212            )?;
213            get_toml_selected_model_device_map_params(&selector)
214        }
215        ModelSelected::MultiModel { .. } => {
216            anyhow::bail!("MultiModel variant should not be used in model loading functions")
217        }
218    }
219}
220
221fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
222    let loader: Box<dyn Loader> = match args.model {
223        ModelSelected::Toml { file } => {
224            let selector: TomlSelector = toml::from_str(
225                &fs::read_to_string(file.clone())
226                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
227            )?;
228            let args = TomlLoaderArgs {
229                chat_template: args.chat_template,
230                no_kv_cache: args.no_kv_cache,
231                jinja_explicit: args.jinja_explicit,
232            };
233            (selector, args).try_into()?
234        }
235        ModelSelected::Plain {
236            model_id,
237            tokenizer_json,
238            arch,
239            dtype: _,
240            topology,
241            organization,
242            write_uqff,
243            from_uqff,
244            imatrix,
245            calibration_file,
246            max_seq_len: _,
247            max_batch_size: _,
248            hf_cache_path,
249            matformer_config_path,
250            matformer_slice_name,
251        } => NormalLoaderBuilder::new(
252            NormalSpecificConfig {
253                topology: Topology::from_option_path(topology)?,
254                organization: organization.unwrap_or_default(),
255                write_uqff,
256                from_uqff: from_uqff.map(|x| {
257                    x.split(UQFF_MULTI_FILE_DELIMITER)
258                        .map(PathBuf::from_str)
259                        .map(|x| x.unwrap())
260                        .collect::<Vec<_>>()
261                }),
262                imatrix,
263                calibration_file,
264                hf_cache_path,
265                matformer_config_path,
266                matformer_slice_name,
267            },
268            args.chat_template,
269            tokenizer_json,
270            Some(model_id),
271            args.no_kv_cache,
272            args.jinja_explicit,
273        )
274        .build(arch)?,
275        ModelSelected::Run {
276            model_id,
277            tokenizer_json,
278            dtype: _,
279            topology,
280            organization,
281            write_uqff,
282            from_uqff,
283            imatrix,
284            calibration_file,
285            max_edge,
286            max_seq_len: _,
287            max_batch_size: _,
288            max_num_images: _,
289            max_image_length: _,
290            hf_cache_path,
291            matformer_config_path,
292            matformer_slice_name,
293        } => {
294            let builder = AutoLoaderBuilder::new(
295                NormalSpecificConfig {
296                    topology: Topology::from_option_path(topology.clone())?,
297                    organization: organization.unwrap_or_default(),
298                    write_uqff: write_uqff.clone(),
299                    from_uqff: from_uqff.clone().map(|x| {
300                        x.split(UQFF_MULTI_FILE_DELIMITER)
301                            .map(PathBuf::from_str)
302                            .map(|x| x.unwrap())
303                            .collect::<Vec<_>>()
304                    }),
305                    imatrix: imatrix.clone(),
306                    calibration_file: calibration_file.clone(),
307                    hf_cache_path: hf_cache_path.clone(),
308                    matformer_config_path: matformer_config_path.clone(),
309                    matformer_slice_name: matformer_slice_name.clone(),
310                },
311                VisionSpecificConfig {
312                    topology: Topology::from_option_path(topology.clone())?,
313                    write_uqff: write_uqff.clone(),
314                    from_uqff: from_uqff.clone().map(|x| {
315                        x.split(UQFF_MULTI_FILE_DELIMITER)
316                            .map(PathBuf::from_str)
317                            .map(|x| x.unwrap())
318                            .collect::<Vec<_>>()
319                    }),
320                    max_edge,
321                    calibration_file,
322                    imatrix,
323                    hf_cache_path: hf_cache_path.clone(),
324                    matformer_config_path,
325                    matformer_slice_name,
326                },
327                EmbeddingSpecificConfig {
328                    topology: Topology::from_option_path(topology)?,
329                    write_uqff,
330                    from_uqff: from_uqff.map(|x| {
331                        x.split(UQFF_MULTI_FILE_DELIMITER)
332                            .map(PathBuf::from_str)
333                            .map(|x| x.unwrap())
334                            .collect::<Vec<_>>()
335                    }),
336                    hf_cache_path: hf_cache_path.clone(),
337                },
338                args.chat_template,
339                tokenizer_json,
340                model_id,
341                args.no_kv_cache,
342                args.jinja_explicit,
343            );
344            let builder = if let Some(ref path) = hf_cache_path {
345                builder.hf_cache_path(path.clone())
346            } else {
347                builder
348            };
349            builder.build()
350        }
351        ModelSelected::VisionPlain {
352            model_id,
353            tokenizer_json,
354            arch,
355            dtype: _,
356            topology,
357            write_uqff,
358            from_uqff,
359            max_edge,
360            calibration_file,
361            max_seq_len: _,
362            max_batch_size: _,
363            max_num_images: _,
364            max_image_length: _,
365            hf_cache_path,
366            imatrix,
367            matformer_config_path,
368            matformer_slice_name,
369        } => VisionLoaderBuilder::new(
370            VisionSpecificConfig {
371                topology: Topology::from_option_path(topology)?,
372                write_uqff,
373                from_uqff: from_uqff.map(|x| {
374                    x.split(UQFF_MULTI_FILE_DELIMITER)
375                        .map(PathBuf::from_str)
376                        .map(|x| x.unwrap())
377                        .collect::<Vec<_>>()
378                }),
379                max_edge,
380                calibration_file,
381                imatrix,
382                hf_cache_path,
383                matformer_config_path,
384                matformer_slice_name,
385            },
386            args.chat_template,
387            tokenizer_json,
388            Some(model_id),
389            args.jinja_explicit,
390        )
391        .build(arch),
392        ModelSelected::DiffusionPlain {
393            model_id,
394            arch,
395            dtype: _,
396        } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
397        ModelSelected::Speech {
398            model_id,
399            dac_model_id,
400            arch,
401            ..
402        } => Box::new(SpeechLoader {
403            model_id,
404            dac_model_id,
405            arch,
406            cfg: None,
407        }),
408        ModelSelected::XLora {
409            model_id,
410            xlora_model_id,
411            order,
412            tokenizer_json,
413            tgt_non_granular_index,
414            arch,
415            dtype: _,
416            topology,
417            write_uqff,
418            from_uqff,
419            max_seq_len: _,
420            max_batch_size: _,
421            hf_cache_path,
422        } => NormalLoaderBuilder::new(
423            NormalSpecificConfig {
424                topology: Topology::from_option_path(topology)?,
425                organization: Default::default(),
426                write_uqff,
427                from_uqff: from_uqff.map(|x| {
428                    x.split(UQFF_MULTI_FILE_DELIMITER)
429                        .map(PathBuf::from_str)
430                        .map(|x| x.unwrap())
431                        .collect::<Vec<_>>()
432                }),
433                imatrix: None,
434                calibration_file: None,
435                hf_cache_path,
436                matformer_config_path: None,
437                matformer_slice_name: None,
438            },
439            args.chat_template,
440            tokenizer_json,
441            model_id,
442            args.no_kv_cache,
443            args.jinja_explicit,
444        )
445        .with_xlora(
446            xlora_model_id,
447            serde_json::from_reader(
448                File::open(order.clone())
449                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
450            )?,
451            args.no_kv_cache,
452            tgt_non_granular_index,
453        )
454        .build(arch)?,
455        ModelSelected::Lora {
456            model_id,
457            tokenizer_json,
458            adapter_model_id,
459            arch,
460            dtype: _,
461            topology,
462            write_uqff,
463            from_uqff,
464            max_seq_len: _,
465            max_batch_size: _,
466            hf_cache_path,
467        } => NormalLoaderBuilder::new(
468            NormalSpecificConfig {
469                topology: Topology::from_option_path(topology)?,
470                organization: Default::default(),
471                write_uqff,
472                from_uqff: from_uqff.map(|x| {
473                    x.split(UQFF_MULTI_FILE_DELIMITER)
474                        .map(PathBuf::from_str)
475                        .map(|x| x.unwrap())
476                        .collect::<Vec<_>>()
477                }),
478                imatrix: None,
479                calibration_file: None,
480                hf_cache_path,
481                matformer_config_path: None,
482                matformer_slice_name: None,
483            },
484            args.chat_template,
485            tokenizer_json,
486            model_id,
487            args.no_kv_cache,
488            args.jinja_explicit,
489        )
490        .with_lora(
491            adapter_model_id
492                .split(MULTI_LORA_DELIMITER)
493                .map(ToString::to_string)
494                .collect(),
495        )
496        .build(arch)?,
497        ModelSelected::GGUF {
498            tok_model_id,
499            quantized_model_id,
500            quantized_filename,
501            topology,
502            ..
503        } => GGUFLoaderBuilder::new(
504            args.chat_template,
505            tok_model_id,
506            quantized_model_id,
507            quantized_filename
508                .split(GGUF_MULTI_FILE_DELIMITER)
509                .map(ToOwned::to_owned)
510                .collect::<Vec<_>>(),
511            GGUFSpecificConfig {
512                topology: Topology::from_option_path(topology)?,
513            },
514            args.no_kv_cache,
515            args.jinja_explicit,
516        )
517        .build(),
518        ModelSelected::XLoraGGUF {
519            tok_model_id,
520            quantized_model_id,
521            quantized_filename,
522            xlora_model_id,
523            order,
524            tgt_non_granular_index,
525            topology,
526            ..
527        } => GGUFLoaderBuilder::new(
528            args.chat_template,
529            tok_model_id,
530            quantized_model_id,
531            quantized_filename
532                .split(GGUF_MULTI_FILE_DELIMITER)
533                .map(ToOwned::to_owned)
534                .collect::<Vec<_>>(),
535            GGUFSpecificConfig {
536                topology: Topology::from_option_path(topology)?,
537            },
538            args.no_kv_cache,
539            args.jinja_explicit,
540        )
541        .with_xlora(
542            xlora_model_id,
543            serde_json::from_reader(
544                File::open(order.clone())
545                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
546            )?,
547            args.no_kv_cache,
548            tgt_non_granular_index,
549        )
550        .build(),
551        ModelSelected::LoraGGUF {
552            tok_model_id,
553            quantized_model_id,
554            quantized_filename,
555            adapters_model_id,
556            order,
557            topology,
558            ..
559        } => GGUFLoaderBuilder::new(
560            args.chat_template,
561            tok_model_id,
562            quantized_model_id,
563            quantized_filename
564                .split(GGUF_MULTI_FILE_DELIMITER)
565                .map(ToOwned::to_owned)
566                .collect::<Vec<_>>(),
567            GGUFSpecificConfig {
568                topology: Topology::from_option_path(topology)?,
569            },
570            args.no_kv_cache,
571            args.jinja_explicit,
572        )
573        .with_lora(
574            adapters_model_id,
575            serde_json::from_reader(
576                File::open(order.clone())
577                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
578            )?,
579        )
580        .build(),
581        ModelSelected::GGML {
582            tok_model_id,
583            tokenizer_json,
584            quantized_model_id,
585            quantized_filename,
586            gqa,
587            topology,
588            ..
589        } => GGMLLoaderBuilder::new(
590            GGMLSpecificConfig {
591                gqa,
592                topology: Topology::from_option_path(topology)?,
593            },
594            args.chat_template,
595            tokenizer_json,
596            Some(tok_model_id),
597            quantized_model_id,
598            quantized_filename,
599            args.no_kv_cache,
600            args.jinja_explicit,
601        )
602        .build(),
603        ModelSelected::XLoraGGML {
604            tok_model_id,
605            tokenizer_json,
606            quantized_model_id,
607            quantized_filename,
608            xlora_model_id,
609            order,
610            tgt_non_granular_index,
611            gqa,
612            topology,
613            ..
614        } => GGMLLoaderBuilder::new(
615            GGMLSpecificConfig {
616                gqa,
617                topology: Topology::from_option_path(topology)?,
618            },
619            args.chat_template,
620            tokenizer_json,
621            tok_model_id,
622            quantized_model_id,
623            quantized_filename,
624            args.no_kv_cache,
625            args.jinja_explicit,
626        )
627        .with_xlora(
628            xlora_model_id,
629            serde_json::from_reader(
630                File::open(order.clone())
631                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
632            )?,
633            args.no_kv_cache,
634            tgt_non_granular_index,
635        )
636        .build(),
637        ModelSelected::LoraGGML {
638            tok_model_id,
639            tokenizer_json,
640            quantized_model_id,
641            quantized_filename,
642            adapters_model_id,
643            order,
644            gqa,
645            topology,
646            ..
647        } => GGMLLoaderBuilder::new(
648            GGMLSpecificConfig {
649                gqa,
650                topology: Topology::from_option_path(topology)?,
651            },
652            args.chat_template,
653            tokenizer_json,
654            tok_model_id,
655            quantized_model_id,
656            quantized_filename,
657            args.no_kv_cache,
658            args.jinja_explicit,
659        )
660        .with_lora(
661            adapters_model_id,
662            serde_json::from_reader(
663                File::open(order.clone())
664                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
665            )?,
666        )
667        .build(),
668        ModelSelected::Embedding {
669            model_id,
670            tokenizer_json,
671            arch,
672            dtype: _,
673            topology,
674            write_uqff,
675            from_uqff,
676            hf_cache_path,
677        } => EmbeddingLoaderBuilder::new(
678            EmbeddingSpecificConfig {
679                topology: Topology::from_option_path(topology)?,
680                write_uqff,
681                from_uqff: from_uqff.map(|x| {
682                    x.split(UQFF_MULTI_FILE_DELIMITER)
683                        .map(PathBuf::from_str)
684                        .map(|x| x.unwrap())
685                        .collect::<Vec<_>>()
686                }),
687                hf_cache_path,
688            },
689            tokenizer_json,
690            Some(model_id),
691        )
692        .build(arch),
693        ModelSelected::MultiModel { .. } => {
694            anyhow::bail!("MultiModel variant should not be used in model loading functions")
695        }
696    };
697    Ok(loader)
698}