// mistralrs_core/pipeline/macros.rs

#[doc(hidden)]
#[macro_export]
macro_rules! api_dir_list {
    ($api:expr, $model_id:expr) => {
        // If `$model_id` names an existing local directory, list its file
        // names; otherwise fetch the repo file listing through the API.
        // Both arms evaluate to a `std::vec::IntoIter<String>`.
        if std::path::Path::new($model_id).exists() {
            std::fs::read_dir($model_id)
                .unwrap_or_else(|_| panic!("Cannot list directory {:?}", $model_id))
                .map(|entry| {
                    entry
                        .unwrap()
                        .path()
                        .file_name()
                        .unwrap() // Should never terminate in `..`
                        .to_str()
                        .expect("Could not convert to str")
                        .to_string()
                })
                .collect::<Vec<String>>()
                .into_iter()
        } else {
            match $api.info() {
                Ok(repo) => repo
                    .siblings
                    .iter()
                    .map(|x| x.rfilename.clone())
                    .collect::<Vec<String>>()
                    .into_iter(),
                Err(e) => panic!("Could not get directory listing from API: {:?}", e),
            }
        }
    };
}
37
#[doc(hidden)]
#[macro_export]
macro_rules! api_get_file {
    ($api:expr, $file:expr, $model_id:expr) => {
        // Resolve `$file` either from a local model directory (when
        // `$model_id` exists on disk) or by downloading it via the API.
        // Evaluates to the file's `PathBuf`; panics if it cannot be found.
        // NOTE: relies on an `info!` macro being in scope at the call site.
        if std::path::Path::new($model_id).exists() {
            let local_path = $model_id.join($file);
            if local_path.exists() {
                info!("Loading `{}` locally at `{}`", &$file, local_path.display());
                local_path
            } else {
                panic!("File \"{}\" not found at model id {:?}", $file, $model_id)
            }
        } else {
            match $api.get($file) {
                Ok(remote_path) => remote_path,
                Err(e) => panic!("Could not get file {:?} from API: {:?}", $file, e),
            }
        }
    };
}
55
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filename:expr,
        $silent:expr,
        $loading_uqff:expr
    ) => {{
        // Build an HF Hub API client honoring the global cache, an optional
        // `HF_HUB_CACHE` env override, and a token from `$token_source`.
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $revision.unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&$this.model_id);
        // Tokenizer: explicit override wins; otherwise fetch from the repo.
        let tokenizer_filename = if let Some(ref p) = $this.tokenizer_json {
            info!("Using tokenizer.json at `{p}`");
            PathBuf::from_str(p)?
        } else {
            info!("Loading `tokenizer.json` at `{}`", $this.model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        };
        info!("Loading `config.json` at `{}`", $this.model_id);
        let config_filename = $crate::api_get_file!(api, "config.json", model_id);
        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            &$quantized_model_id,
            &$quantized_filename,
            &api,
            &model_id,
            $loading_uqff,
        )?;
        let XLoraPaths {
            adapter_configs,
            adapter_safetensors,
            classifier_path,
            xlora_order,
            xlora_config,
            lora_preload_adapter_info,
        } = get_xlora_paths(
            $this.model_id.clone(),
            &$this.xlora_model_id,
            &$token_source,
            revision.clone(),
            &$this.xlora_order,
        )?;
        // List the repo contents ONCE and reuse the result for every
        // optional-file check below. Previously each check re-ran
        // `api_dir_list!`, which is a full HTTP round trip (`api.info()`)
        // for remote models.
        let dir_list = $crate::api_dir_list!(api, model_id).collect::<Vec<_>>();
        let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) {
            info!("Loading `generation_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "generation_config.json",
                model_id
            ))
        } else {
            None
        };
        let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string()) {
            info!("Loading `preprocessor_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "preprocessor_config.json",
                model_id
            ))
        } else {
            None
        };
        let processor_config = if dir_list.contains(&"processor_config.json".to_string()) {
            info!("Loading `processor_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "processor_config.json",
                model_id
            ))
        } else {
            None
        };
        // Chat template: explicit override wins; otherwise use the repo's
        // `tokenizer_config.json`.
        let template_filename = if let Some(ref p) = $this.chat_template {
            info!("Using chat template file at `{p}`");
            Some(PathBuf::from_str(p)?)
        } else {
            info!("Loading `tokenizer_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "tokenizer_config.json",
                model_id
            ))
        };
        let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) {
            info!("Loading `chat_template.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(api, "chat_template.json", model_id))
        } else {
            None
        };
        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename,
            filenames,
            xlora_adapter_configs: adapter_configs,
            xlora_adapter_filenames: adapter_safetensors,
            classifier_path,
            classifier_config: xlora_config,
            xlora_ordering: xlora_order,
            template_filename,
            gen_conf,
            lora_preload_adapter_info,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}
196
#[doc(hidden)]
#[macro_export]
macro_rules! get_uqff_paths {
    ($from_uqff:expr, $this:expr, $silent:expr) => {{
        // Resolve the UQFF artifact `$from_uqff` either from a local model
        // directory or by downloading it from the Hub repo at the pipeline's
        // recorded revision.
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token(
                    &$this
                        .token_source
                        .read()
                        .expect("Failed to read token source")
                        .clone()
                        .unwrap_or(TokenSource::None),
                )?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $this
            .revision
            .read()
            .expect("Failed to read revision")
            .clone()
            .unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.to_string(),
            RepoType::Model,
            revision.clone(),
        ));

        let file = $from_uqff.display().to_string();

        // Fix: use `$crate::`-qualified paths so this `#[macro_export]` macro
        // expands correctly outside the defining crate (the bare
        // `api_get_file!` / `Path` names previously had to be in scope at the
        // call site), matching the convention used by `get_paths!`.
        $crate::api_get_file!(api, &file, std::path::Path::new(&$this.model_id))
    }};
}
236
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths_gguf {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filenames:expr,
        $silent:expr
    ) => {{
        // Build an HF Hub API client honoring the global cache, an optional
        // `HF_HUB_CACHE` env override, and a token from `$token_source`.
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $revision.unwrap_or("main".to_string());
        // The tokenizer/config repo may differ from the quantized-weights
        // repo; fall back to the quantized model id when none is given.
        let this_model_id = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone());
        let api = api.repo(Repo::with_revision(
            this_model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&this_model_id);

        let chat_template = if let Some(ref p) = $this.chat_template {
            if p.ends_with(".json") {
                info!("Using chat template file at `{p}`");
                Some(PathBuf::from_str(p)?)
            } else {
                panic!("Specified chat template file must end with .json");
            }
        } else {
            if $this.model_id.is_none() {
                None
            } else {
                info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id);
                let res = $crate::api_get_file!(
                    api,
                    "tokenizer_config.json",
                    model_id
                );
                Some(res)
            }
        };

        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            &Some($quantized_model_id),
            &Some($quantized_filenames),
            &api,
            &model_id,
            false, // Never loading UQFF
        )?;

        let XLoraPaths {
            adapter_configs,
            adapter_safetensors,
            classifier_path,
            xlora_order,
            xlora_config,
            lora_preload_adapter_info,
        } = get_xlora_paths(
            this_model_id.clone(),
            &$this.xlora_model_id,
            &$token_source,
            revision.clone(),
            &$this.xlora_order,
        )?;

        // List the repo contents ONCE and reuse the result for every
        // optional-file check below. Previously each check re-ran
        // `api_dir_list!`, which is a full HTTP round trip (`api.info()`)
        // for remote models.
        let dir_list = $crate::api_dir_list!(api, model_id).collect::<Vec<_>>();

        let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) {
            info!("Loading `generation_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "generation_config.json",
                model_id
            ))
        } else {
            None
        };

        let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string()) {
            info!("Loading `preprocessor_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "preprocessor_config.json",
                model_id
            ))
        } else {
            None
        };

        let processor_config = if dir_list.contains(&"processor_config.json".to_string()) {
            info!("Loading `processor_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "processor_config.json",
                model_id
            ))
        } else {
            None
        };

        // No separate model id: an empty path is used as a sentinel —
        // presumably the GGUF file itself supplies the tokenizer (TODO
        // confirm against the consumer of `tokenizer_filename`).
        let tokenizer_filename = if $this.model_id.is_some() {
            info!("Loading `tokenizer.json` at `{}`", this_model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        } else {
            PathBuf::from_str("")?
        };

        let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) {
            info!("Loading `chat_template.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "chat_template.json",
                model_id
            ))
        } else {
            None
        };

        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename: PathBuf::from_str("")?, // No config.json for GGUF; empty sentinel
            filenames,
            xlora_adapter_configs: adapter_configs,
            xlora_adapter_filenames: adapter_safetensors,
            classifier_path,
            classifier_config: xlora_config,
            xlora_ordering: xlora_order,
            template_filename: chat_template,
            gen_conf,
            lora_preload_adapter_info,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}
396
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
    ) => {{
        // When ISQ is applied together with UQFF, layers matched by the
        // loader's regexes are loaded as dummy weights — they get overwritten
        // later. MoQE models use a separate regex set.
        let regexes = if $loading_isq && $loading_uqff {
            let layer_regexes = if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            };
            Some(std::sync::Arc::new(layer_regexes))
        } else {
            None
        };
        let device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        // Memory-map every weight file into a single VarBuilder.
        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            device_for_tensor,
        )?;

        $loader.load(
            &$config,
            $use_flash_attn,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}
456
#[doc(hidden)]
#[macro_export]
// Sharded variant of `normal_model_loader!`: the VarBuilder `$vb` is already
// constructed by the caller (no safetensors mmap-ing here), so this simply
// forwards everything to the loader. Expands to an expression that uses `?`,
// so the surrounding function must return a compatible `Result`.
macro_rules! normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        $loader.load(
            &$config,
            $use_flash_attn,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}
485
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        // When ISQ is applied together with UQFF, layers matched by the
        // loader's regexes are loaded as dummy weights — they get
        // overwritten later.
        let regexes = if $loading_isq && $loading_uqff {
            Some(std::sync::Arc::new($loader.isq_layer_regexes(&$config)?))
        } else {
            None
        };
        let device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        // Memory-map every weight file into a single VarBuilder.
        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            device_for_tensor,
        )?;

        $loader.load(
            &$config,
            $use_flash_attn,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}
540
#[doc(hidden)]
#[macro_export]
// Sharded variant of `vision_normal_model_loader!`: the VarBuilder `$vb` is
// already constructed by the caller (no safetensors mmap-ing here), so this
// simply forwards everything to the loader. Expands to an expression that
// uses `?`, so the surrounding function must return a compatible `Result`.
macro_rules! vision_normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        $loader.load(
            &$config,
            $use_flash_attn,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}
569
#[doc(hidden)]
#[macro_export]
macro_rules! xlora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $multi_progress:expr,
    ) => {{
        // The X-LoRA classifier weights are mmap-ed alongside the base
        // model weights.
        let weight_files = $paths
            .get_weight_filenames()
            .iter()
            .chain(std::iter::once(
                $paths.get_classifier_path().as_ref().unwrap(),
            ))
            .cloned()
            .collect::<Vec<_>>();
        let device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            weight_files,
            $paths
                .get_adapter_filenames()
                .as_ref()
                .unwrap()
                .iter()
                .map(|(_, path)| path.clone())
                .collect::<Vec<_>>(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            None,
            |_| true,
            device_for_tensor,
        )?;

        $loader.load_xlora(
            &$config,
            $use_flash_attn,
            vb,
            $paths.get_adapter_configs().as_ref().unwrap(),
            Some($paths.get_classifier_config().as_ref().unwrap().clone()),
            $paths.get_ordering().as_ref().unwrap().clone(),
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            &None,
        )?
    }};
}
630
#[doc(hidden)]
#[macro_export]
macro_rules! lora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $multi_progress:expr,
    ) => {{
        // Base weights and adapter safetensors are mmap-ed together; unlike
        // X-LoRA there is no classifier file here.
        let weight_files = $paths.get_weight_filenames().to_vec();
        let device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            weight_files,
            $paths
                .get_adapter_filenames()
                .as_ref()
                .unwrap()
                .iter()
                .map(|(_, path)| path.clone())
                .collect::<Vec<_>>(),
            Some($dtype),
            $device,
            $layer_devices,
            $silent,
            None,
            |_| true,
            device_for_tensor,
        )?;

        $loader.load_xlora(
            &$config,
            $use_flash_attn,
            vb,
            $paths.get_adapter_configs().as_ref().unwrap(),
            None,
            $paths.get_ordering().as_ref().unwrap().clone(),
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            &$crate::utils::varbuilder_utils::load_preload_adapters(
                $paths.get_lora_preload_adapter_info(),
                $dtype,
                $device,
                $silent,
            )?,
        )?
    }};
}