mistralrs_core/pipeline/macros.rs

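/// Lists the files of a model repo and yields them as an iterator of `String` file names.
///
/// If `$model_id` is an existing local directory, its entries are read directly; otherwise the
/// listing is fetched from the Hub via `$api.info()` and cached on disk as
/// `<sanitized id>_repo_list.json` under `HF_HUB_CACHE` (falling back to
/// `~/.cache/huggingface/hub/`). On an API error the macro panics when `$should_panic` is true,
/// and otherwise logs a warning and yields an empty listing.
///
/// Minimal usage sketch, mirroring the call sites later in this file:
/// ```ignore
/// let dir_list = $crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>();
/// ```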
#[doc(hidden)]
#[macro_export]
macro_rules! api_dir_list {
    ($api:expr, $model_id:expr, $should_panic:expr) => {
        if std::path::Path::new($model_id).exists() {
            let listing = std::fs::read_dir($model_id);
            if listing.is_err() {
                panic!("Cannot list directory {:?}", $model_id)
            }
            let listing = listing.unwrap();
            listing
                .into_iter()
                .map(|s| {
                    s.unwrap()
                        .path()
                        .file_name()
                        .unwrap() // Should never terminate in `..`
                        .to_str()
                        .expect("Could not convert to str")
                        .to_string()
                })
                .collect::<Vec<String>>()
                .into_iter()
        } else {
            let sanitized_id = std::path::Path::new($model_id)
                .display()
                .to_string()
                .replace("/", "-");

            let home_folder = if dirs::home_dir().is_some() {
                let mut path = dirs::home_dir().unwrap();
                path.push(".cache/huggingface/hub/");
                if !path.exists() {
                    let _ = std::fs::create_dir_all(&path);
                }
                path
            } else {
                "./".into()
            };

            let cache_dir: std::path::PathBuf = std::env::var("HF_HUB_CACHE")
                .map(std::path::PathBuf::from)
                .unwrap_or(home_folder.into());
            let cache_file = cache_dir.join(format!("{sanitized_id}_repo_list.json"));
            if std::path::Path::new(&cache_file).exists() {
                use std::io::Read;
                // Read from cache
                let mut file = std::fs::File::open(&cache_file).expect("Could not open cache file");
                let mut contents = String::new();
                file.read_to_string(&mut contents)
                    .expect("Could not read cache file");
                let cache: $crate::pipeline::FileListCache =
                    serde_json::from_str(&contents).expect("Could not parse cache JSON");
                tracing::info!("Read from cache file {:?}", cache_file);
                cache.files.into_iter()
            } else {
                $api.info()
                    .map(|repo| {
                        let files: Vec<String> = repo
                            .siblings
                            .iter()
                            .map(|x| x.rfilename.clone())
                            .collect::<Vec<String>>();
                        // Save to cache
                        let cache = $crate::pipeline::FileListCache {
                            files: files.clone(),
                        };
                        let json = serde_json::to_string_pretty(&cache)
                            .expect("Could not serialize cache");
                        let ret = std::fs::write(&cache_file, json);
                        tracing::info!("Write to cache file {:?}, {:?}", cache_file, ret);
                        files
                    })
                    .unwrap_or_else(|e| {
                        if $should_panic {
                            panic!("Could not get directory listing from API: {:?}", e)
                        } else {
                            tracing::warn!("Could not get directory listing from API: {:?}", e);
                            Vec::<String>::new()
                        }
                    })
                    .into_iter()
            }
        }
    };
}

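/// Resolves a single repo file to a local `PathBuf`.
///
/// If `$model_id` is an existing local directory, the file is looked up inside it (panicking if
/// it is missing); otherwise it is downloaded through `$api.get($file)`. Expects an `info!`
/// macro (e.g. from `tracing`) to be in scope at the call site.
///
/// Sketch, as used elsewhere in this file:
/// ```ignore
/// let config_filename = $crate::api_get_file!(api, "config.json", model_id);
/// ```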
#[doc(hidden)]
#[macro_export]
macro_rules! api_get_file {
    ($api:expr, $file:expr, $model_id:expr) => {
        if std::path::Path::new($model_id).exists() {
            let path = $model_id.join($file);
            if !path.exists() {
                panic!("File \"{}\" not found at model id {:?}", $file, $model_id)
            }
            info!("Loading `{}` locally at `{}`", &$file, path.display());
            path
        } else {
            $api.get($file)
                .unwrap_or_else(|e| panic!("Could not get file {:?} from API: {:?}", $file, e))
        }
    };
}

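/// Collects every path needed to load a regular safetensors model and evaluates to
/// `Ok(Box::new($path_name { .. }))`: tokenizer, config, weight files, adapter paths, and the
/// optional generation/preprocessor/processor configs and chat template files.
///
/// Must be invoked where `ApiBuilder`, `Repo`, `RepoType`, `get_token`, `get_model_paths`,
/// `get_xlora_paths`, `info!`, `PathBuf`, and `FromStr` are in scope, and where `?` can
/// propagate the errors produced here.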
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filename:expr,
        $silent:expr,
        $loading_uqff:expr
    ) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $revision.unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&$this.model_id);
        let tokenizer_filename = if let Some(ref p) = $this.tokenizer_json {
            info!("Using tokenizer.json at `{p}`");
            PathBuf::from_str(p)?
        } else {
            info!("Loading `tokenizer.json` at `{}`", $this.model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        };
        info!("Loading `config.json` at `{}`", $this.model_id);
        let config_filename = $crate::api_get_file!(api, "config.json", model_id);
        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            $quantized_model_id.as_ref(),
            $quantized_filename.as_ref(),
            &api,
            &model_id,
            $loading_uqff,
        )?;
        let adapter_paths = get_xlora_paths(
            $this.model_id.clone(),
            $this.xlora_model_id.as_ref(),
            $this.lora_adapter_ids.as_ref(),
            &$token_source,
            revision.clone(),
            $this.xlora_order.as_ref(),
        )?;
        let dir_list = $crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>();

        let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) {
            info!("Loading `generation_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "generation_config.json",
                model_id
            ))
        } else {
            None
        };
        let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string()) {
            info!("Loading `preprocessor_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "preprocessor_config.json",
                model_id
            ))
        } else {
            None
        };
        let processor_config = if dir_list.contains(&"processor_config.json".to_string()) {
            info!("Loading `processor_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "processor_config.json",
                model_id
            ))
        } else {
            None
        };
        let template_filename = if let Some(ref p) = $this.chat_template {
            info!("Using chat template file at `{p}`");
            Some(PathBuf::from_str(p)?)
        } else if dir_list.contains(&"chat_template.jinja".to_string()) {
            info!("Loading `chat_template.jinja` at `{}`", $this.model_id);
            Some($crate::api_get_file!(api, "chat_template.jinja", model_id))
        } else {
            info!("Loading `tokenizer_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "tokenizer_config.json",
                model_id
            ))
        };
        let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) {
            info!("Loading `chat_template.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(api, "chat_template.json", model_id))
        } else {
            None
        };
        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename,
            filenames,
            adapter_paths,
            template_filename,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}

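/// Variant of `get_paths!` for embedding models: X-LoRA is skipped, and if a
/// sentence-transformers style `modules.json` is present, its Transformer/Pooling/Dense/Normalize
/// entries are resolved into `EmbeddingModulePaths` (fetching each module's `config.json` and,
/// for Dense modules, its `model.safetensors`).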
#[doc(hidden)]
#[macro_export]
macro_rules! get_embedding_paths {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filename:expr,
        $silent:expr,
        $loading_uqff:expr
    ) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $revision.unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&$this.model_id);
        let tokenizer_filename = if let Some(ref p) = $this.tokenizer_json {
            info!("Using tokenizer.json at `{p}`");
            PathBuf::from_str(p)?
        } else {
            info!("Loading `tokenizer.json` at `{}`", $this.model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        };
        info!("Loading `config.json` at `{}`", $this.model_id);
        let config_filename = $crate::api_get_file!(api, "config.json", model_id);
        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            $quantized_model_id.as_ref(),
            $quantized_filename.as_ref(),
            &api,
            &model_id,
            $loading_uqff,
        )?;
        let adapter_paths = get_xlora_paths(
            $this.model_id.clone(),
            None, // no xlora
            $this.lora_adapter_ids.as_ref(),
            &$token_source,
            revision.clone(),
            None, // no xlora
        )?;

        let mut parsed_modules = Vec::new();
        let is_local = std::path::Path::new(&$this.model_id).exists();
        let modules_path = if is_local {
            model_id.join("modules.json")
        } else {
            $crate::api_get_file!(api, "modules.json", model_id)
        };

        if modules_path.exists() {
            let modules: Vec<$crate::pipeline::EmbeddingModule> =
                serde_json::from_str(&std::fs::read_to_string(&modules_path)?)?;
            for module in modules {
                match module.ty {
                    $crate::pipeline::EmbeddingModuleType::Transformer => {
                        parsed_modules.push($crate::pipeline::EmbeddingModulePaths::Transformer {
                            path: module.path.clone(),
                        });
                    }
                    $crate::pipeline::EmbeddingModuleType::Pooling => {
                        parsed_modules.push($crate::pipeline::EmbeddingModulePaths::Pooling {
                            path: module.path.clone(),
                            config: $crate::api_get_file!(
                                api,
                                &format!("{}/config.json", module.path),
                                model_id
                            ),
                        });
                    }
                    $crate::pipeline::EmbeddingModuleType::Dense => {
                        parsed_modules.push($crate::pipeline::EmbeddingModulePaths::Dense {
                            path: module.path.clone(),
                            config: $crate::api_get_file!(
                                api,
                                &format!("{}/config.json", module.path),
                                model_id
                            ),
                            model: $crate::api_get_file!(
                                api,
                                &format!("{}/model.safetensors", module.path),
                                model_id
                            ),
                        });
                    }
                    $crate::pipeline::EmbeddingModuleType::Normalize => {
                        parsed_modules.push($crate::pipeline::EmbeddingModulePaths::Normalize {
                            path: module.path.clone(),
                        });
                    }
                }
            }
        }

        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename,
            filenames,
            adapter_paths,
            modules: parsed_modules,
        }))
    }};
}

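/// Downloads (or locates locally) each UQFF file listed in `$from_uqff` for `$this.model_id`,
/// returning the resolved paths as a `Vec`. The token source and revision are read from the
/// lock-protected fields on `$this`.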
#[doc(hidden)]
#[macro_export]
macro_rules! get_uqff_paths {
    ($from_uqff:expr, $this:expr, $silent:expr) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token(
                    &$this
                        .token_source
                        .read()
                        .expect("Failed to read token source")
                        .clone()
                        .unwrap_or(TokenSource::None),
                )?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $this
            .revision
            .read()
            .expect("Failed to read revision")
            .clone()
            .unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.to_string(),
            RepoType::Model,
            revision.clone(),
        ));

        let mut files = Vec::new();
        for file in $from_uqff {
            let file = file.display().to_string();

            files.push(api_get_file!(api, &file, Path::new(&$this.model_id)));
        }
        files
    }};
}

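/// Counterpart of `get_paths!` for GGUF models. The quantized repo id is used when no plain
/// `model_id` is given, UQFF loading is always disabled, and `tokenizer.json` / `config.json`
/// fall back to empty placeholder paths when absent.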
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths_gguf {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filenames:expr,
        $silent:expr
    ) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $revision.unwrap_or("main".to_string());
        let this_model_id = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone());
        let api = api.repo(Repo::with_revision(
            this_model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&this_model_id);

        let dir_list = $crate::api_dir_list!(api, model_id, false)
            .collect::<Vec<_>>();

        let chat_template = if let Some(ref p) = $this.chat_template {
            if p.ends_with(".json") || p.ends_with(".jinja") {
                info!("Using chat template file at `{p}`");
                Some(PathBuf::from_str(p)?)
            } else {
                panic!("Specified chat template file must end with .json or .jinja");
            }
        } else if $this.model_id.is_none() {
            None
        } else if dir_list.contains(&"chat_template.jinja".to_string()) {
            info!("Loading `chat_template.jinja` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "chat_template.jinja",
                model_id
            ))
        } else {
            info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id);
            let res = $crate::api_get_file!(
                api,
                "tokenizer_config.json",
                model_id
            );
            Some(res)
        };

        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            Some(&$quantized_model_id),
            Some(&$quantized_filenames),
            &api,
            &model_id,
            false, // Never loading UQFF
        )?;

        info!("GGUF file(s) {:?}", filenames);
        let adapter_paths = get_xlora_paths(
            this_model_id.clone(),
            $this.xlora_model_id.as_ref(),
            $this.lora_adapter_ids.as_ref(),
            &$token_source,
            revision.clone(),
            $this.xlora_order.as_ref(),
        )?;

        let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) {
            info!("Loading `generation_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "generation_config.json",
                model_id
            ))
        } else {
            None
        };

        let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string()) {
            info!("Loading `preprocessor_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "preprocessor_config.json",
                model_id
            ))
        } else {
            None
        };

        let processor_config = if dir_list.contains(&"processor_config.json".to_string()) {
            info!("Loading `processor_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "processor_config.json",
                model_id
            ))
        } else {
            None
        };

        let tokenizer_filename = if $this.model_id.is_some() && dir_list.contains(&"tokenizer.json".to_string()) {
            info!("Loading `tokenizer.json` at `{}`", this_model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        } else {
            PathBuf::from_str("")?
        };

        let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) {
            info!("Loading `chat_template.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "chat_template.json",
                model_id
            ))
        } else {
            None
        };

        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename: PathBuf::from_str("")?,
            filenames,
            adapter_paths,
            template_filename: chat_template,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}

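/// Builds a `VarBuilder` from the model's memory-mapped safetensors (substituting dummy ISQ
/// layer regexes when both ISQ and UQFF loading are requested) and hands it to
/// `$loader.load(..)` together with `NormalLoadingMetadata`. Expects `from_mmaped_safetensors`
/// to be in scope at the call site.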
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        let regexes = if $loading_isq && $loading_uqff {
            // Dummy weights for the layers which will be overwritten...
            Some(std::sync::Arc::new(if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            }))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            get_device_for_tensor,
        )?;

        $loader.load(
            &$config,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}

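/// Like `normal_model_loader!`, but takes an already-constructed `$vb` (e.g. when the weights
/// were loaded or sharded elsewhere) and only performs the final `$loader.load(..)` call.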
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        $loader.load(
            &$config,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}

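/// Vision-model counterpart of `normal_model_loader!`: identical flow, but without the
/// `$is_moqe` regex branch.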
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        let regexes = if $loading_isq && $loading_uqff {
            // Dummy weights for the layers which will be overwritten...
            Some(std::sync::Arc::new($loader.isq_layer_regexes(&$config)?))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            get_device_for_tensor,
        )?;

        $loader.load(
            &$config,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}

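/// Sharded-`VarBuilder` variant of `vision_normal_model_loader!`; simply forwards the provided
/// `$vb` to `$loader.load(..)`.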
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        $loader.load(
            &$config,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}

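/// Embedding-model counterpart of `normal_model_loader!`; no MatFormer slicing config is
/// threaded through (`matformer_slicing_config: None`).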
#[doc(hidden)]
#[macro_export]
macro_rules! embedding_normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        let regexes = if $loading_isq && $loading_uqff {
            // Dummy weights for the layers which will be overwritten...
            Some(std::sync::Arc::new($loader.isq_layer_regexes(&$config)?))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            get_device_for_tensor,
        )?;

        $loader.load(
            &$config,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: None,
            },
            $attention_mechanism,
        )?
    }};
}

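/// Sharded-`VarBuilder` variant of `embedding_normal_model_loader!`.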
#[doc(hidden)]
#[macro_export]
macro_rules! embedding_normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        $loader.load(
            &$config,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: None,
            },
            $attention_mechanism,
        )?
    }};
}

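/// Loads an X-LoRA model: the base weights plus the classifier safetensors are mapped together
/// with the adapter safetensors, then `$loader.load_xlora(..)` is called with the adapter
/// configs, X-LoRA config, and ordering taken from `AdapterPaths::XLora`.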
#[doc(hidden)]
#[macro_export]
macro_rules! xlora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        // TODO: remove lora_preload_adapter_info
        let $crate::pipeline::AdapterPaths::XLora {
            adapter_configs,
            adapter_safetensors,
            classifier_path,
            xlora_order,
            xlora_config,
            lora_preload_adapter_info: _,
        } = $paths.get_adapter_paths()
        else {
            unreachable!()
        };

        let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
        safetensors_paths.push(classifier_path.as_ref().unwrap());
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            safetensors_paths
                .iter()
                .map(|x| (*x).to_owned())
                .collect::<Vec<_>>(),
            adapter_safetensors
                .as_ref()
                .unwrap()
                .iter()
                .map(|(_, x)| (*x).to_owned())
                .collect::<Vec<_>>(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            None,
            |_| true,
            get_device_for_tensor,
        )?;

        $loader.load_xlora(
            &$config,
            vb,
            adapter_configs.as_ref().unwrap(),
            Some(xlora_config.as_ref().unwrap().clone()),
            xlora_order.as_ref().unwrap().clone(),
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            &None,
        )?
    }};
}

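/// Loads a model with plain LoRA adapters: the base weights are mapped as usual, each adapter in
/// `AdapterPaths::Lora` is loaded into its own `VarBuilder` and registered via
/// `mistralrs_quant::push_applied_lora`, and finally `$loader.load(..)` builds the model.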
#[doc(hidden)]
#[macro_export]
macro_rules! lora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        let $crate::pipeline::AdapterPaths::Lora(lora_adapter_paths) = $paths.get_adapter_paths()
        else {
            unreachable!()
        };

        let regexes = if $loading_isq && $loading_uqff {
            // Dummy weights for the layers which will be overwritten...
            Some(std::sync::Arc::new(if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            }))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            get_device_for_tensor.clone(),
        )?;

        for $crate::pipeline::LoraAdapterPaths {
            adapter_path,
            lora_config,
        } in lora_adapter_paths
        {
            let lora_vb = from_mmaped_safetensors(
                vec![adapter_path.clone()],
                Vec::new(),
                $dtype,
                $device,
                $layer_devices,
                $silent,
                None,
                |_| true,
                get_device_for_tensor.clone(),
            )?;

            mistralrs_quant::push_applied_lora(mistralrs_quant::LoraAdapter {
                config: lora_config.clone(),
                weights: lora_vb,
            });
        }

        $loader.load(
            &$config,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}