// mistralrs_core/pipeline/macros.rs

#[doc(hidden)]
#[macro_export]
/// List the files available for a model, either from a local directory or
/// from the Hugging Face Hub API.
///
/// Expands to an expression yielding a `std::vec::IntoIter<String>` of file
/// names:
/// - If `$model_id` exists as a local path, the directory is read from disk.
/// - Otherwise, `$api.info()` is queried and the repo sibling filenames are
///   returned.
///
/// Panics if the local directory cannot be listed, a directory entry cannot
/// be read, or the API request fails.
macro_rules! api_dir_list {
    ($api:expr, $model_id:expr) => {
        if std::path::Path::new($model_id).exists() {
            // Local model: list the directory contents on disk.
            // `read_dir` already yields an iterator; no extra `into_iter` needed.
            std::fs::read_dir($model_id)
                .unwrap_or_else(|_| panic!("Cannot list directory {:?}", $model_id))
                .map(|entry| {
                    entry
                        .expect("Could not read directory entry")
                        .path()
                        .file_name()
                        .unwrap() // Should never terminate in `..`
                        .to_str()
                        .expect("Could not convert to str")
                        .to_string()
                })
                .collect::<Vec<String>>()
                .into_iter()
        } else {
            // Remote model: ask the Hub API for the repo file listing.
            $api.info()
                .map(|repo| {
                    repo.siblings
                        .iter()
                        .map(|x| x.rfilename.clone())
                        .collect::<Vec<String>>()
                })
                .unwrap_or_else(|e| panic!("Could not get directory listing from API: {:?}", e))
                .into_iter()
        }
    };
}
37
#[doc(hidden)]
#[macro_export]
/// Resolve a single model file to a local path.
///
/// If `$model_id` is an existing local directory, the file must already be
/// present inside it (panics otherwise). If not, the file is fetched via
/// `$api.get` from the Hugging Face Hub, panicking on failure.
macro_rules! api_get_file {
    ($api:expr, $file:expr, $model_id:expr) => {
        if !std::path::Path::new($model_id).exists() {
            // Remote model: fetch the file through the Hub API.
            $api.get($file)
                .unwrap_or_else(|e| panic!("Could not get file {:?} from API: {:?}", $file, e))
        } else {
            // Local model: the file must already exist on disk.
            let local_path = $model_id.join($file);
            if !local_path.exists() {
                panic!("File \"{}\" not found at model id {:?}", $file, $model_id)
            }
            info!("Loading `{}` locally at `{}`", &$file, local_path.display());
            local_path
        }
    };
}
55
#[doc(hidden)]
#[macro_export]
/// Resolve all file paths needed to load a safetensors-based model
/// (tokenizer, config, weights, adapters, and optional auxiliary configs),
/// constructing a boxed `$path_name` struct.
///
/// Expands to a block evaluating to `Ok(Box::new($path_name { .. }))`. It
/// uses `?`, so the surrounding function must return a compatible `Result`.
/// Files are read from the local directory when `$this.model_id` is an
/// existing path; otherwise they are fetched from the Hugging Face Hub.
macro_rules! get_paths {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filename:expr,
        $silent:expr,
        $loading_uqff:expr
    ) => {{
        // Build the Hub API client, backed by the global HF cache (if set)
        // and authenticated with the provided token source.
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            // `HF_HUB_CACHE` overrides the default cache directory when set.
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        // Default to the `main` revision when none was requested.
        let revision = $revision.unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&$this.model_id);
        // Tokenizer: an explicit tokenizer.json override wins over the repo copy.
        let tokenizer_filename = if let Some(ref p) = $this.tokenizer_json {
            info!("Using tokenizer.json at `{p}`");
            PathBuf::from_str(p)?
        } else {
            info!("Loading `tokenizer.json` at `{}`", $this.model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        };
        info!("Loading `config.json` at `{}`", $this.model_id);
        let config_filename = $crate::api_get_file!(api, "config.json", model_id);
        // Weight file paths (safetensors or UQFF, per `$loading_uqff`).
        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            &$quantized_model_id,
            &$quantized_filename,
            &api,
            &model_id,
            $loading_uqff,
        )?;
        // X-LoRA / LoRA adapter paths (resolved from the pipeline's adapter config).
        let adapter_paths = get_xlora_paths(
            $this.model_id.clone(),
            &$this.xlora_model_id,
            &$this.lora_adapter_ids,
            &$token_source,
            revision.clone(),
            &$this.xlora_order,
        )?;
        // NOTE(review): each `api_dir_list!` call below re-lists the repo;
        // the listing could be fetched once and reused.
        let gen_conf = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"generation_config.json".to_string())
        {
            info!("Loading `generation_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "generation_config.json",
                model_id
            ))
        } else {
            None
        };
        // Optional preprocessor configuration.
        let preprocessor_config = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"preprocessor_config.json".to_string())
        {
            info!("Loading `preprocessor_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "preprocessor_config.json",
                model_id
            ))
        } else {
            None
        };
        // Optional processor configuration.
        let processor_config = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"processor_config.json".to_string())
        {
            info!("Loading `processor_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "processor_config.json",
                model_id
            ))
        } else {
            None
        };
        // Chat template: an explicit template file wins over tokenizer_config.json.
        let template_filename = if let Some(ref p) = $this.chat_template {
            info!("Using chat template file at `{p}`");
            Some(PathBuf::from_str(p)?)
        } else {
            info!("Loading `tokenizer_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "tokenizer_config.json",
                model_id
            ))
        };
        // Some repos additionally ship a standalone chat_template.json.
        let chat_template_json_filename = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"chat_template.json".to_string())
        {
            info!("Loading `chat_template.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(api, "chat_template.json", model_id))
        } else {
            None
        };
        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename,
            filenames,
            adapter_paths,
            template_filename,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}
185
#[doc(hidden)]
#[macro_export]
/// Resolve each UQFF file in `$from_uqff` to a local path, downloading from
/// the Hugging Face Hub when the model id is not a local directory.
///
/// Evaluates to a `Vec` of resolved paths, preserving the input order. Uses
/// `?`, so the surrounding function must return a compatible `Result`.
macro_rules! get_uqff_paths {
    ($from_uqff:expr, $this:expr, $silent:expr) => {{
        // Build the Hub API client from the pipeline's stored token source,
        // backed by the global HF cache (if set).
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token(
                    &$this
                        .token_source
                        .read()
                        .expect("Failed to read token source")
                        .clone()
                        .unwrap_or(TokenSource::None),
                )?);
            // `HF_HUB_CACHE` overrides the default cache directory when set.
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        // Use the pipeline's stored revision, defaulting to `main`.
        let revision = $this
            .revision
            .read()
            .expect("Failed to read revision")
            .clone()
            .unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.to_string(),
            RepoType::Model,
            revision.clone(),
        ));

        // Resolve every requested UQFF file in order.
        let mut files = Vec::new();
        for file in $from_uqff {
            let file = file.display().to_string();

            files.push(api_get_file!(api, &file, Path::new(&$this.model_id)));
        }
        files
    }};
}
229
#[doc(hidden)]
#[macro_export]
/// Resolve file paths for loading a GGUF model, constructing a boxed
/// `$path_name` struct.
///
/// Differences from `get_paths!`: the weights come from the quantized
/// `$quantized_model_id`/`$quantized_filenames`, UQFF is never loaded, and
/// `config.json` is not fetched (`config_filename` is set to an empty path).
/// Expands to a block using `?`, so the surrounding function must return a
/// compatible `Result`.
macro_rules! get_paths_gguf {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filenames:expr,
        $silent:expr
    ) => {{
        // Build the Hub API client, backed by the global HF cache (if set)
        // and authenticated with the provided token source.
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            // `HF_HUB_CACHE` overrides the default cache directory when set.
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $revision.unwrap_or("main".to_string());
        // Tokenizer/template files come from `model_id` when given, otherwise
        // from the quantized repo itself.
        let this_model_id = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone());
        let api = api.repo(Repo::with_revision(
            this_model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&this_model_id);

        // Chat template: an explicit override must be a .json file; otherwise
        // fall back to tokenizer_config.json (only when a model id was given).
        let chat_template = if let Some(ref p) = $this.chat_template {
            if p.ends_with(".json") {
                info!("Using chat template file at `{p}`");
                Some(PathBuf::from_str(p)?)
            } else {
                panic!("Specified chat template file must end with .json");
            }
        } else {
            if $this.model_id.is_none() {
                None
            } else {
                info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id);
                let res = $crate::api_get_file!(
                    api,
                    "tokenizer_config.json",
                    model_id
                );
                Some(res)
            }
        };

        // GGUF weight file paths.
        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            &Some($quantized_model_id),
            &Some($quantized_filenames),
            &api,
            &model_id,
            false, // Never loading UQFF
        )?;

        // X-LoRA / LoRA adapter paths (resolved from the pipeline's adapter config).
        let adapter_paths = get_xlora_paths(
            this_model_id.clone(),
            &$this.xlora_model_id,
            &$this.lora_adapter_ids,
            &$token_source,
            revision.clone(),
            &$this.xlora_order,
        )?;

        // NOTE(review): each `api_dir_list!` call below re-lists the repo;
        // the listing could be fetched once and reused.
        let gen_conf = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"generation_config.json".to_string())
        {
            info!("Loading `generation_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "generation_config.json",
                model_id
            ))
        } else {
            None
        };

        // Optional preprocessor configuration.
        let preprocessor_config = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"preprocessor_config.json".to_string())
        {
            info!("Loading `preprocessor_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "preprocessor_config.json",
                model_id
            ))
        } else {
            None
        };

        // Optional processor configuration.
        let processor_config = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"processor_config.json".to_string())
        {
            info!("Loading `processor_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "processor_config.json",
                model_id
            ))
        } else {
            None
        };

        // tokenizer.json is only fetched when a separate model id was given;
        // otherwise an empty path is stored — NOTE(review): presumably the
        // caller then derives the tokenizer from the GGUF itself; confirm.
        let tokenizer_filename = if $this.model_id.is_some() {
            info!("Loading `tokenizer.json` at `{}`", this_model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        } else {
            PathBuf::from_str("")?
        };

        // Some repos additionally ship a standalone chat_template.json.
        let chat_template_json_filename = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"chat_template.json".to_string())
        {
            info!("Loading `chat_template.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "chat_template.json",
                model_id
            ))
        } else {
            None
        };

        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename: PathBuf::from_str("")?, // GGUF: no separate config.json
            filenames,
            adapter_paths,
            template_filename: chat_template,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}
378
#[doc(hidden)]
#[macro_export]
/// Load a normal (text) model from its safetensors weight files.
///
/// Memory-maps the weight files into a `VarBuilder` (with dummy-weight
/// regexes when both ISQ and UQFF loading are requested) and hands it to the
/// loader along with the loading metadata. Expands to an expression using
/// `?`, so the surrounding function must return a compatible `Result`.
macro_rules! normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
    ) => {{
        // Dummy weights for the layers which will be overwritten...
        let regexes = if $loading_isq && $loading_uqff {
            let layer_regexes = if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            };
            Some(std::sync::Arc::new(layer_regexes))
        } else {
            None
        };

        let device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        // Memory-map every weight file into a single VarBuilder.
        let varbuilder = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            device_for_tensor,
        )?;

        $loader.load(
            &$config,
            $use_flash_attn,
            varbuilder,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}
438
#[doc(hidden)]
#[macro_export]
/// Load a normal (text) model from an already-constructed (sharded)
/// `VarBuilder`, forwarding it with the loading metadata to the loader.
/// Unlike `normal_model_loader!`, no safetensors mapping happens here.
/// Expands to an expression using `?`.
macro_rules! normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        $loader.load(
            &$config,
            $use_flash_attn,
            $vb,
            // Metadata consumed by the concrete model implementation.
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}
467
#[doc(hidden)]
#[macro_export]
/// Load a vision model from its safetensors weight files.
///
/// Same flow as `normal_model_loader!` but without the MoQE regex variant:
/// memory-maps the weights into a `VarBuilder` (with dummy-weight regexes
/// when both ISQ and UQFF loading are requested) and hands it to the loader.
/// Expands to an expression using `?`.
macro_rules! vision_normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        // Dummy weights for the layers which will be overwritten...
        let regexes = if $loading_isq && $loading_uqff {
            let layer_regexes = $loader.isq_layer_regexes(&$config)?;
            Some(std::sync::Arc::new(layer_regexes))
        } else {
            None
        };

        let device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        // Memory-map every weight file into a single VarBuilder.
        let varbuilder = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            device_for_tensor,
        )?;

        $loader.load(
            &$config,
            $use_flash_attn,
            varbuilder,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}
522
#[doc(hidden)]
#[macro_export]
/// Load a vision model from an already-constructed (sharded) `VarBuilder`,
/// forwarding it with the loading metadata to the loader. Unlike
/// `vision_normal_model_loader!`, no safetensors mapping happens here.
/// Expands to an expression using `?`.
macro_rules! vision_normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        $loader.load(
            &$config,
            $use_flash_attn,
            $vb,
            // Metadata consumed by the concrete model implementation.
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}
551
#[doc(hidden)]
#[macro_export]
/// Load an X-LoRA model: the base safetensors weights plus the X-LoRA
/// classifier, with the adapter safetensors passed as a second weight list.
/// Expands to an expression using `?`.
///
/// The `unwrap`s on the adapter fields assume the X-LoRA paths were fully
/// populated by `get_xlora_paths`; a caller passing non-X-LoRA paths hits
/// `unreachable!()`.
macro_rules! xlora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $multi_progress:expr,
    ) => {{
        // TODO: remove lora_preload_adapter_info
        // The paths must have been built for X-LoRA; anything else is a caller bug.
        let $crate::pipeline::AdapterPaths::XLora {
            adapter_configs,
            adapter_safetensors,
            classifier_path,
            xlora_order,
            xlora_config,
            lora_preload_adapter_info: _,
        } = $paths.get_adapter_paths()
        else {
            unreachable!()
        };

        // The classifier weights are loaded alongside the base model weights.
        let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
        safetensors_paths.push(classifier_path.as_ref().unwrap());
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            // First list: base weights + classifier.
            safetensors_paths
                .iter()
                .map(|x| (*x).to_owned())
                .collect::<Vec<_>>(),
            // Second list: the X-LoRA adapter weight files.
            adapter_safetensors
                .as_ref()
                .unwrap()
                .iter()
                .map(|(_, x)| (*x).to_owned())
                .collect::<Vec<_>>(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            None,
            |_| true,
            get_device_for_tensor,
        )?;

        $loader.load_xlora(
            &$config,
            $use_flash_attn,
            vb,
            adapter_configs.as_ref().unwrap(),
            Some(xlora_config.as_ref().unwrap().clone()),
            xlora_order.as_ref().unwrap().clone(),
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            &None,
        )?
    }};
}
624
#[doc(hidden)]
#[macro_export]
/// Load a model with LoRA adapters: the base weights are loaded as in
/// `normal_model_loader!`, and each adapter's config and weights are pushed
/// onto the global `mistralrs_quant::APPLIED_LORAS` list before the model is
/// built. Expands to an expression using `?`.
macro_rules! lora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
    ) => {{
        // The paths must have been built for LoRA; anything else is a caller bug.
        let $crate::pipeline::AdapterPaths::Lora(lora_adapter_paths) = $paths.get_adapter_paths()
        else {
            unreachable!()
        };

        let regexes = if $loading_isq && $loading_uqff {
            // Dummy weights for the layers which will be overwritten...
            Some(std::sync::Arc::new(if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            }))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        // Base model weights.
        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true, // Will be overwritten...
            get_device_for_tensor.clone(),
        )?;

        // Load each adapter's weights and register them in the global,
        // process-wide APPLIED_LORAS list.
        for $crate::pipeline::LoraAdapterPaths {
            adapter_path,
            lora_config,
        } in lora_adapter_paths
        {
            let lora_vb = from_mmaped_safetensors(
                vec![adapter_path.clone()],
                Vec::new(),
                $dtype,
                $device,
                $layer_devices,
                $silent,
                None,
                |_| true,
                get_device_for_tensor.clone(),
            )?;

            mistralrs_quant::APPLIED_LORAS
                .lock()
                .unwrap()
                .push(mistralrs_quant::LoraAdapter {
                    config: lora_config.clone(),
                    weights: lora_vb,
                });
        }

        $loader.load(
            &$config,
            $use_flash_attn,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}