mistralrs_core/pipeline/
macros.rs

/// Produces an iterator over the file names (`String`s) available for a model.
///
/// Arguments:
/// - `$api`: a value with an `.info()` method; on success the returned value's
///   `siblings` are iterated and each `rfilename` is collected.
/// - `$model_id`: a reference accepted by `std::path::Path::new` (for example `&Path`).
///   It is evaluated exactly once.
/// - `$should_panic`: if `true`, a failed `$api.info()` call panics; otherwise the
///   failure is logged as a warning and an empty listing is produced.
///
/// Resolution order:
/// 1. If `$model_id` exists on the local filesystem, it is listed with
///    `std::fs::read_dir`.
/// 2. Otherwise, if `<cache dir>/<model id with '/' replaced by '-'>_repo_list.json`
///    exists and parses as `FileListCache`, its `files` are used. The cache directory is
///    `HF_HUB_CACHE` when set, else `~/.cache/huggingface/hub/` (created on demand),
///    else `./`.
/// 3. Otherwise `$api.info()` is queried and the listing is written to the cache file
///    (best effort; the write result is only logged).
///
/// A cache file that cannot be read or parsed is ignored with a warning and the API is
/// queried instead, which also rewrites the cache file.
///
/// # Panics
/// - The local directory exists but cannot be listed, or an entry cannot be read.
/// - A local file name is not valid UTF-8.
/// - The listing cannot be serialized to JSON.
/// - `$api.info()` fails and `$should_panic` is `true`.
#[doc(hidden)]
#[macro_export]
macro_rules! api_dir_list {
    ($api:expr, $model_id:expr, $should_panic:expr $(,)?) => {
        // `match` binds `$model_id` once while keeping any temporaries created by that
        // expression alive for the whole body.
        match $model_id {
            model_id => {
                if std::path::Path::new(model_id).exists() {
                    let names = match std::fs::read_dir(model_id) {
                        Ok(listing) => listing
                            .map(|entry| {
                                entry
                                    .expect("Could not read directory entry")
                                    // `DirEntry::file_name` is the bare name, so there is
                                    // no `..` / empty-component case to handle.
                                    .file_name()
                                    .to_str()
                                    .expect("Could not convert to str")
                                    .to_string()
                            })
                            .collect::<Vec<String>>(),
                        Err(_) => panic!("Cannot list directory {:?}", model_id),
                    };
                    names.into_iter()
                } else {
                    let sanitized_id = std::path::Path::new(model_id)
                        .display()
                        .to_string()
                        .replace("/", "-");

                    // Only compute (and create) the default cache location when
                    // `HF_HUB_CACHE` does not override it.
                    let cache_dir: std::path::PathBuf = std::env::var("HF_HUB_CACHE")
                        .map(std::path::PathBuf::from)
                        .unwrap_or_else(|_| match dirs::home_dir() {
                            Some(mut path) => {
                                path.push(".cache/huggingface/hub/");
                                if !path.exists() {
                                    let _ = std::fs::create_dir_all(&path);
                                }
                                path
                            }
                            None => std::path::PathBuf::from("./"),
                        });
                    let cache_file = cache_dir.join(format!("{sanitized_id}_repo_list.json"));

                    // Try the on-disk cache first; any failure degrades to "no cache".
                    let cached_files: Option<Vec<String>> = if cache_file.exists() {
                        let parsed = std::fs::read_to_string(&cache_file)
                            .map_err(|e| e.to_string())
                            .and_then(|contents| {
                                serde_json::from_str::<$crate::pipeline::FileListCache>(&contents)
                                    .map_err(|e| e.to_string())
                            });
                        match parsed {
                            Ok(cache) => {
                                tracing::info!("Read from cache file {:?}", cache_file);
                                Some(cache.files)
                            }
                            Err(e) => {
                                tracing::warn!(
                                    "Ignoring unreadable cache file {:?}: {}",
                                    cache_file,
                                    e
                                );
                                None
                            }
                        }
                    } else {
                        None
                    };

                    let files: Vec<String> = match cached_files {
                        Some(files) => files,
                        None => $api
                            .info()
                            .map(|repo| {
                                // Save to cache
                                let cache = $crate::pipeline::FileListCache {
                                    files: repo
                                        .siblings
                                        .iter()
                                        .map(|x| x.rfilename.clone())
                                        .collect::<Vec<String>>(),
                                };
                                let json = serde_json::to_string_pretty(&cache)
                                    .expect("Could not serialize cache");
                                let ret = std::fs::write(&cache_file, json);
                                tracing::info!(
                                    "Write to cache file {:?}, {:?}",
                                    cache_file,
                                    ret
                                );
                                cache.files
                            })
                            .unwrap_or_else(|e| {
                                if $should_panic {
                                    panic!("Could not get directory listing from API: {:?}", e)
                                } else {
                                    tracing::warn!(
                                        "Could not get directory listing from API: {:?}",
                                        e
                                    );
                                    Vec::<String>::new()
                                }
                            }),
                    };
                    files.into_iter()
                }
            }
        }
    };
}
87
/// Resolves `$file` for a model to a path, either locally or through `$api`.
///
/// Arguments:
/// - `$api`: a value with a `.get(file)` method returning a `Result` whose `Ok` value is
///   the resolved path.
/// - `$file`: the file name (a string reference such as `"config.json"` or `&String`).
/// - `$model_id`: a path reference with a `.join` method (for example `&Path`).
///
/// `$model_id` and `$file` are each evaluated exactly once.
///
/// If `$model_id` exists on the local filesystem, the result is `$model_id.join($file)`;
/// otherwise the result is whatever `$api.get($file)` returns.
///
/// An `info!` macro must be in scope at the call site; it is used to log local loads.
///
/// # Panics
/// - `$model_id` exists locally but `$file` is not found inside it.
/// - `$api.get($file)` returns an error.
#[doc(hidden)]
#[macro_export]
macro_rules! api_get_file {
    ($api:expr, $file:expr, $model_id:expr $(,)?) => {
        // Bind both arguments once so side effects in the argument expressions are not
        // repeated; `match` keeps their temporaries alive for the whole body.
        match ($model_id, $file) {
            (model_id, file) => {
                if std::path::Path::new(model_id).exists() {
                    let path = model_id.join(file);
                    if !path.exists() {
                        panic!("File \"{}\" not found at model id {:?}", file, model_id)
                    }
                    info!("Loading `{}` locally at `{}`", &file, path.display());
                    path
                } else {
                    $api.get(file).unwrap_or_else(|e| {
                        panic!("Could not get file {:?} from API: {:?}", file, e)
                    })
                }
            }
        }
    };
}
105
/// Resolves every file needed to load a model and evaluates to
/// `Ok(Box::new($path_name { .. }))`.
///
/// Must be expanded inside a function returning a `Result`, because `?` is used on the
/// API builder, token lookup, path parsing and the `get_model_paths` /
/// `get_xlora_paths` calls. The names `ApiBuilder`, `get_token`, `Repo`, `RepoType`,
/// `PathBuf`, `FromStr`, `get_model_paths`, `get_xlora_paths` and `info!` must be in
/// scope at the call site.
///
/// Arguments:
/// - `$path_name`: struct to construct; it must have the fields populated at the end.
/// - `$token_source`: passed to `get_token` and, by reference, to the path helpers.
/// - `$revision`: `Option<String>`; defaults to `"main"`.
/// - `$this`: value providing `model_id`, `tokenizer_json`, `chat_template`,
///   `xlora_model_id`, `lora_adapter_ids` and `xlora_order`.
/// - `$quantized_model_id`, `$quantized_filename`: forwarded via `.as_ref()` to
///   `get_model_paths`.
/// - `$silent`: progress reporting is enabled when this is `false`.
/// - `$loading_uqff`: forwarded to `get_model_paths`.
///
/// `tokenizer.json` (unless `$this.tokenizer_json` is set) and `config.json` are always
/// fetched. `generation_config.json`, `preprocessor_config.json`,
/// `processor_config.json` and `chat_template.json` are fetched only when present in
/// the model's file listing. The chat template is `$this.chat_template` if set, else
/// `chat_template.jinja` if listed, else `tokenizer_config.json`.
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filename:expr,
        $silent:expr,
        $loading_uqff:expr $(,)?
    ) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        // Only allocate the default revision when none was supplied.
        let revision = $revision.unwrap_or_else(|| "main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&$this.model_id);
        let tokenizer_filename = if let Some(ref p) = $this.tokenizer_json {
            info!("Using tokenizer.json at `{p}`");
            PathBuf::from_str(p)?
        } else {
            info!("Loading `tokenizer.json` at `{}`", $this.model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        };
        info!("Loading `config.json` at `{}`", $this.model_id);
        let config_filename = $crate::api_get_file!(api, "config.json", model_id);
        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            $quantized_model_id.as_ref(),
            $quantized_filename.as_ref(),
            &api,
            &model_id,
            $loading_uqff,
        )?;
        let adapter_paths = get_xlora_paths(
            $this.model_id.clone(),
            $this.xlora_model_id.as_ref(),
            $this.lora_adapter_ids.as_ref(),
            &$token_source,
            revision.clone(),
            $this.xlora_order.as_ref(),
        )?;
        let dir_list = $crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>();

        // Fetches `name` only if the listing contains it. Comparing as `&str` avoids
        // allocating a `String` for every membership check.
        let load_optional = |name: &str| {
            if dir_list.iter().any(|entry| entry.as_str() == name) {
                info!("Loading `{}` at `{}`", name, $this.model_id);
                Some($crate::api_get_file!(api, name, model_id))
            } else {
                None
            }
        };

        let gen_conf = load_optional("generation_config.json");
        let preprocessor_config = load_optional("preprocessor_config.json");
        let processor_config = load_optional("processor_config.json");
        let template_filename = if let Some(ref p) = $this.chat_template {
            info!("Using chat template file at `{p}`");
            Some(PathBuf::from_str(p)?)
        } else if let Some(path) = load_optional("chat_template.jinja") {
            Some(path)
        } else {
            // Last resort: this file is requested even if it is not in the listing.
            info!("Loading `tokenizer_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "tokenizer_config.json",
                model_id
            ))
        };
        let chat_template_json_filename = load_optional("chat_template.json");
        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename,
            filenames,
            adapter_paths,
            template_filename,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}
228
/// Resolves each entry of `$from_uqff` to a path for the model `$this.model_id` and
/// evaluates to a `Vec` of those paths, in iteration order.
///
/// Must be expanded inside a function returning a `Result` (`?` is used on the token
/// lookup and API build). `ApiBuilder`, `get_token`, `TokenSource`, `Repo`, `RepoType`
/// and `info!` (needed by `api_get_file!`) must be in scope at the call site.
///
/// Arguments:
/// - `$from_uqff`: iterable whose items expose `.display()` (path-like values).
/// - `$this`: value providing `model_id`, plus `token_source` and `revision` behind
///   locks with a `.read()` method. A missing token source becomes
///   `TokenSource::None`; a missing revision becomes `"main"`.
/// - `$silent`: progress reporting is enabled when this is `false`.
///
/// # Panics
/// - Reading the `token_source` or `revision` lock fails.
/// - `api_get_file!` cannot resolve one of the files.
#[doc(hidden)]
#[macro_export]
macro_rules! get_uqff_paths {
    ($from_uqff:expr, $this:expr, $silent:expr $(,)?) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token(
                    &$this
                        .token_source
                        .read()
                        .expect("Failed to read token source")
                        .clone()
                        .unwrap_or(TokenSource::None),
                )?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        // Only allocate the default revision when none is stored.
        let revision = $this
            .revision
            .read()
            .expect("Failed to read revision")
            .clone()
            .unwrap_or_else(|| "main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.to_string(),
            RepoType::Model,
            revision,
        ));

        let mut files = Vec::new();
        for file in $from_uqff {
            let file = file.display().to_string();

            // Fully qualified so this exported macro does not depend on `api_get_file!`
            // or `Path` being imported where it is expanded.
            files.push($crate::api_get_file!(
                api,
                &file,
                std::path::Path::new(&$this.model_id)
            ));
        }
        files
    }};
}
272
/// Resolves every file needed to load a GGUF model and evaluates to
/// `Ok(Box::new($path_name { .. }))`.
///
/// Must be expanded inside a function returning a `Result` (`?` is used). The names
/// `ApiBuilder`, `get_token`, `Repo`, `RepoType`, `PathBuf`, `FromStr`,
/// `get_model_paths`, `get_xlora_paths` and `info!` must be in scope at the call site.
///
/// Arguments:
/// - `$path_name`: struct to construct; it must have the fields populated at the end.
/// - `$token_source`: passed to `get_token` and, by reference, to the path helpers.
/// - `$revision`: `Option<String>`; defaults to `"main"`.
/// - `$this`: value providing `model_id` (an `Option`), `quantized_model_id`,
///   `chat_template`, `xlora_model_id`, `lora_adapter_ids` and `xlora_order`.
/// - `$quantized_model_id`, `$quantized_filenames`: forwarded by reference to
///   `get_model_paths`.
/// - `$silent`: progress reporting is enabled when this is `false`.
///
/// The repository used is `$this.model_id`, falling back to
/// `$this.quantized_model_id` when it is `None`. `config_filename` is always an empty
/// path. `tokenizer_filename` is an empty path unless `$this.model_id` is set and
/// `tokenizer.json` is listed. The chat template is `$this.chat_template` if set, else
/// `None` when `$this.model_id` is `None`, else `chat_template.jinja` if listed, else
/// `tokenizer_config.json`.
///
/// # Panics
/// If `$this.chat_template` is set but does not end with `.json` or `.jinja`.
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths_gguf {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filenames:expr,
        $silent:expr $(,)?
    ) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        // Defaults are built lazily so nothing is allocated or cloned when a value is
        // already present.
        let revision = $revision.unwrap_or_else(|| "main".to_string());
        let this_model_id = $this
            .model_id
            .clone()
            .unwrap_or_else(|| $this.quantized_model_id.clone());
        let api = api.repo(Repo::with_revision(
            this_model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&this_model_id);

        let dir_list = $crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>();

        // Fetches `name` only if the listing contains it. Comparing as `&str` avoids
        // allocating a `String` for every membership check.
        let load_optional = |name: &str| {
            if dir_list.iter().any(|entry| entry.as_str() == name) {
                info!("Loading `{}` at `{}`", name, this_model_id);
                Some($crate::api_get_file!(api, name, model_id))
            } else {
                None
            }
        };

        let chat_template = if let Some(ref p) = $this.chat_template {
            if p.ends_with(".json") || p.ends_with(".jinja") {
                info!("Using chat template file at `{p}`");
                Some(PathBuf::from_str(p)?)
            } else {
                panic!("Specified chat template file must end with .json or .jinja");
            }
        } else if $this.model_id.is_none() {
            None
        } else if let Some(path) = load_optional("chat_template.jinja") {
            Some(path)
        } else {
            // Last resort: this file is requested even if it is not in the listing.
            info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id);
            Some($crate::api_get_file!(
                api,
                "tokenizer_config.json",
                model_id
            ))
        };

        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            Some(&$quantized_model_id),
            Some(&$quantized_filenames),
            &api,
            &model_id,
            false, // Never loading UQFF
        )?;

        info!("GGUF file(s) {:?}", filenames);
        let adapter_paths = get_xlora_paths(
            this_model_id.clone(),
            $this.xlora_model_id.as_ref(),
            $this.lora_adapter_ids.as_ref(),
            &$token_source,
            revision.clone(),
            $this.xlora_order.as_ref(),
        )?;

        let gen_conf = load_optional("generation_config.json");
        let preprocessor_config = load_optional("preprocessor_config.json");
        let processor_config = load_optional("processor_config.json");

        // An empty path stands for "no tokenizer.json available".
        let tokenizer_filename = if $this.model_id.is_some() {
            load_optional("tokenizer.json").unwrap_or_else(PathBuf::new)
        } else {
            PathBuf::new()
        };

        let chat_template_json_filename = load_optional("chat_template.json");

        Ok(Box::new($path_name {
            tokenizer_filename,
            // No `config.json` is fetched here; the field is left as an empty path.
            config_filename: PathBuf::new(),
            filenames,
            adapter_paths,
            template_filename: chat_template,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}
421
/// Builds the weight source with `from_mmaped_safetensors` and evaluates to the value
/// returned by `$loader.load(..)`.
///
/// Must be expanded inside a function returning a `Result` (`?` is used), with
/// `from_mmaped_safetensors` in scope at the call site. The trailing comma after the
/// last argument is optional.
///
/// Arguments:
/// - `$paths`: provides `get_weight_filenames()`; the result is copied with `to_vec()`.
/// - `$dtype`, `$device`, `$layer_devices`, `$silent`: forwarded unchanged to
///   `from_mmaped_safetensors`.
/// - `$config`: passed by reference to the `$loader` methods.
/// - `$loader`: provides `isq_layer_regexes`, `isq_layer_regexes_moqe`,
///   `get_device_for_tensor` and `load`.
/// - `$mapper`: dereferenced for `get_device_for_tensor`, then moved into
///   `NormalLoadingMetadata`.
/// - `$loading_isq`, `$loading_uqff`: when both are `true`, layer regexes from the
///   loader are passed to `from_mmaped_safetensors`; otherwise `None` is passed.
/// - `$is_moqe`: selects `isq_layer_regexes_moqe` over `isq_layer_regexes`.
/// - `$real_device`, `$multi_progress`, `$matformer_config`: stored in
///   `NormalLoadingMetadata`.
/// - `$attention_mechanism`: last argument to `$loader.load`.
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
        $matformer_config:expr $(,)?
    ) => {{
        let regexes = if $loading_isq && $loading_uqff {
            // Dummy weights for the layers which will be overwritten...
            Some(std::sync::Arc::new(if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            }))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            get_device_for_tensor,
        )?;

        $loader.load(
            &$config,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}
481
/// Calls `$loader.load(..)` with an already-built `$vb` and evaluates to its `Ok` value.
///
/// Must be expanded inside a function returning a `Result` (`?` is used). The trailing
/// comma after the last argument is optional.
///
/// Arguments:
/// - `$vb`: passed to `load` as its second argument.
/// - `$config`: passed by reference as the first argument.
/// - `$loader`: provides `load`.
/// - `$mapper`, `$loading_isq`, `$real_device`, `$multi_progress`,
///   `$matformer_config`: stored in `NormalLoadingMetadata` (the last one as
///   `matformer_slicing_config`).
/// - `$attention_mechanism`: last argument to `load`.
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
        $matformer_config:expr $(,)?
    ) => {{
        $loader.load(
            &$config,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}
510
/// Builds the weight source with `from_mmaped_safetensors` and evaluates to the value
/// returned by `$loader.load(..)`.
///
/// Must be expanded inside a function returning a `Result` (`?` is used), with
/// `from_mmaped_safetensors` in scope at the call site. The trailing comma after the
/// last argument is optional.
///
/// Arguments:
/// - `$paths`: provides `get_weight_filenames()`; the result is copied with `to_vec()`.
/// - `$dtype`, `$device`, `$layer_devices`, `$silent`: forwarded unchanged to
///   `from_mmaped_safetensors`.
/// - `$config`: passed by reference to the `$loader` methods.
/// - `$loader`: provides `isq_layer_regexes`, `get_device_for_tensor` and `load`.
/// - `$mapper`: dereferenced for `get_device_for_tensor`, then moved into
///   `NormalLoadingMetadata`.
/// - `$loading_isq`, `$loading_uqff`: when both are `true`, the loader's
///   `isq_layer_regexes` are passed to `from_mmaped_safetensors`; otherwise `None`.
/// - `$real_device`, `$multi_progress`, `$matformer_config`: stored in
///   `NormalLoadingMetadata`.
/// - `$attention_mechanism`: last argument to `$loader.load`.
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
        $matformer_config:expr $(,)?
    ) => {{
        let regexes = if $loading_isq && $loading_uqff {
            // Dummy weights for the layers which will be overwritten...
            Some(std::sync::Arc::new($loader.isq_layer_regexes(&$config)?))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            get_device_for_tensor,
        )?;

        $loader.load(
            &$config,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}
565
/// Calls `$loader.load(..)` with an already-built `$vb` and evaluates to its `Ok` value.
///
/// Must be expanded inside a function returning a `Result` (`?` is used). The trailing
/// comma after the last argument is optional.
///
/// Arguments:
/// - `$vb`: passed to `load` as its second argument.
/// - `$config`: passed by reference as the first argument.
/// - `$loader`: provides `load`.
/// - `$mapper`, `$loading_isq`, `$real_device`, `$multi_progress`,
///   `$matformer_config`: stored in `NormalLoadingMetadata` (the last one as
///   `matformer_slicing_config`).
/// - `$attention_mechanism`: last argument to `load`.
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
        $matformer_config:expr $(,)?
    ) => {{
        $loader.load(
            &$config,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}
594
/// Builds the weight source for a model with X-LoRA adapters and evaluates to the value
/// returned by `$loader.load_xlora(..)`.
///
/// Must be expanded inside a function returning a `Result` (`?` is used), with
/// `from_mmaped_safetensors` in scope at the call site. The trailing comma after the
/// last argument is optional.
///
/// Arguments:
/// - `$paths`: provides `get_adapter_paths()` (which must yield the
///   `AdapterPaths::XLora` variant) and `get_weight_filenames()`.
/// - `$dtype`, `$device`, `$layer_devices`, `$silent`: forwarded unchanged to
///   `from_mmaped_safetensors`.
/// - `$config`: passed by reference to the `$loader` methods.
/// - `$loader`: provides `get_device_for_tensor` and `load_xlora`.
/// - `$mapper`: dereferenced for `get_device_for_tensor`, then moved into
///   `NormalLoadingMetadata`.
/// - `$loading_isq`, `$real_device`, `$multi_progress`, `$matformer_config`: stored in
///   `NormalLoadingMetadata`.
///
/// The classifier path is appended to the model weight files; the adapter safetensors
/// paths are passed as the second file list.
///
/// # Panics
/// - The adapter paths are not `AdapterPaths::XLora`.
/// - Any of `classifier_path`, `adapter_safetensors`, `adapter_configs`,
///   `xlora_config` or `xlora_order` is `None`.
#[doc(hidden)]
#[macro_export]
macro_rules! xlora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $multi_progress:expr,
        $matformer_config:expr $(,)?
    ) => {{
        // TODO: remove lora_preload_adapter_info
        let $crate::pipeline::AdapterPaths::XLora {
            adapter_configs,
            adapter_safetensors,
            classifier_path,
            xlora_order,
            xlora_config,
            lora_preload_adapter_info: _,
        } = $paths.get_adapter_paths()
        else {
            unreachable!("xlora_model_loader requires `AdapterPaths::XLora` adapter paths")
        };

        let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
        // The classifier weights are loaded together with the base model weights.
        safetensors_paths.push(
            classifier_path
                .as_ref()
                .expect("X-LoRA adapter paths are missing `classifier_path`"),
        );
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            safetensors_paths
                .iter()
                .map(|x| (*x).to_owned())
                .collect::<Vec<_>>(),
            adapter_safetensors
                .as_ref()
                .expect("X-LoRA adapter paths are missing `adapter_safetensors`")
                .iter()
                .map(|(_, x)| (*x).to_owned())
                .collect::<Vec<_>>(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            None,
            |_| true,
            get_device_for_tensor,
        )?;

        $loader.load_xlora(
            &$config,
            vb,
            adapter_configs
                .as_ref()
                .expect("X-LoRA adapter paths are missing `adapter_configs`"),
            Some(
                xlora_config
                    .as_ref()
                    .expect("X-LoRA adapter paths are missing `xlora_config`")
                    .clone(),
            ),
            xlora_order
                .as_ref()
                .expect("X-LoRA adapter paths are missing `xlora_order`")
                .clone(),
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            &None,
        )?
    }};
}
667
/// Builds the weight source for a model with LoRA adapters, registers every adapter via
/// `mistralrs_quant::push_applied_lora`, and evaluates to the value returned by
/// `$loader.load(..)`.
///
/// Must be expanded inside a function returning a `Result` (`?` is used), with
/// `from_mmaped_safetensors` in scope at the call site. The trailing comma after the
/// last argument is optional.
///
/// Arguments:
/// - `$paths`: provides `get_adapter_paths()` (which must yield the
///   `AdapterPaths::Lora` variant) and `get_weight_filenames()`.
/// - `$dtype`, `$device`, `$layer_devices`, `$silent`: forwarded unchanged to every
///   `from_mmaped_safetensors` call (once for the base weights, once per adapter).
/// - `$config`: passed by reference to the `$loader` methods.
/// - `$loader`: provides `isq_layer_regexes`, `isq_layer_regexes_moqe`,
///   `get_device_for_tensor` and `load`.
/// - `$mapper`: dereferenced for `get_device_for_tensor`, then moved into
///   `NormalLoadingMetadata`.
/// - `$loading_isq`, `$loading_uqff`: when both are `true`, layer regexes from the
///   loader are passed for the base weights; otherwise `None`. Adapters always get
///   `None`.
/// - `$is_moqe`: selects `isq_layer_regexes_moqe` over `isq_layer_regexes`.
/// - `$real_device`, `$multi_progress`, `$matformer_config`: stored in
///   `NormalLoadingMetadata`.
/// - `$attention_mechanism`: last argument to `$loader.load`.
///
/// Adapters are pushed before `$loader.load` runs; if a later step returns early
/// through `?`, adapters already pushed are not removed by this macro.
///
/// # Panics
/// If the adapter paths are not `AdapterPaths::Lora`.
#[doc(hidden)]
#[macro_export]
macro_rules! lora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
        $matformer_config:expr $(,)?
    ) => {{
        let $crate::pipeline::AdapterPaths::Lora(lora_adapter_paths) = $paths.get_adapter_paths()
        else {
            unreachable!("lora_model_loader requires `AdapterPaths::Lora` adapter paths")
        };

        let regexes = if $loading_isq && $loading_uqff {
            // Dummy weights for the layers which will be overwritten...
            Some(std::sync::Arc::new(if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            }))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            get_device_for_tensor.clone(),
        )?;

        // Each adapter gets its own weight source, built with the same device mapping.
        for $crate::pipeline::LoraAdapterPaths {
            adapter_path,
            lora_config,
        } in lora_adapter_paths
        {
            let lora_vb = from_mmaped_safetensors(
                vec![adapter_path.clone()],
                Vec::new(),
                $dtype,
                $device,
                $layer_devices,
                $silent,
                None,
                |_| true,
                get_device_for_tensor.clone(),
            )?;

            mistralrs_quant::push_applied_lora(mistralrs_quant::LoraAdapter {
                config: lora_config.clone(),
                weights: lora_vb,
            });
        }

        $loader.load(
            &$config,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}