/// Produces an iterator over the file names (`String`s) available for a model.
///
/// * When `$model_id` exists as a local path, the names of the entries directly inside that
///   directory are yielded.
/// * Otherwise a `<sanitized id>_repo_list.json` cache file is looked up inside the directory
///   named by the `HF_HUB_CACHE` environment variable, falling back to
///   `<home>/.cache/huggingface/hub/` (or `./` when no home directory is reported). If the cache
///   file exists its `files` list is yielded; if not, `$api.info()` is queried, the sibling file
///   names are written to the cache file (best effort, the write result is only logged) and
///   then yielded.
///
/// `$should_panic` decides what happens when `$api.info()` fails: `true` panics, `false` logs a
/// warning and yields an empty listing.
///
/// # Panics
///
/// Panics if the local directory cannot be listed, if an entry or its name cannot be read or
/// converted to UTF-8, if an existing cache file cannot be opened, read or parsed, if the
/// listing cannot be serialized, or if the API call fails while `$should_panic` is `true`.
#[doc(hidden)]
#[macro_export]
macro_rules! api_dir_list {
    ($api:expr, $model_id:expr, $should_panic:expr) => {
        if std::path::Path::new($model_id).exists() {
            // Local model directory: report the names of its direct entries.
            let entries = match std::fs::read_dir($model_id) {
                Ok(entries) => entries,
                Err(_) => panic!("Cannot list directory {:?}", $model_id),
            };
            let mut names = Vec::<String>::new();
            for entry in entries {
                let entry_path = entry.unwrap().path();
                let name = entry_path
                    .file_name()
                    .unwrap()
                    .to_str()
                    .expect("Could not convert to str")
                    .to_string();
                names.push(name);
            }
            names.into_iter()
        } else {
            // Remote model: `/` in the id is replaced so the id can be used as a file name.
            let sanitized_id = std::path::Path::new($model_id)
                .display()
                .to_string()
                .replace("/", "-");

            // Default cache location; created eagerly, creation errors are ignored.
            let home_folder = match dirs::home_dir() {
                Some(mut hub_dir) => {
                    hub_dir.push(".cache/huggingface/hub/");
                    if !hub_dir.exists() {
                        let _ = std::fs::create_dir_all(&hub_dir);
                    }
                    hub_dir
                }
                None => std::path::PathBuf::from("./"),
            };

            let cache_dir: std::path::PathBuf = match std::env::var("HF_HUB_CACHE") {
                Ok(dir) => std::path::PathBuf::from(dir),
                Err(_) => home_folder,
            };
            let cache_file = cache_dir.join(format!("{sanitized_id}_repo_list.json"));
            if cache_file.exists() {
                use std::io::Read;
                let mut contents = String::new();
                std::fs::File::open(&cache_file)
                    .expect("Could not open cache file")
                    .read_to_string(&mut contents)
                    .expect("Could not read cache file");
                let cache: $crate::pipeline::FileListCache =
                    serde_json::from_str(&contents).expect("Could not parse cache JSON");
                tracing::info!("Read from cache file {:?}", cache_file);
                cache.files.into_iter()
            } else {
                let listed: Vec<String> = match $api.info() {
                    Ok(repo) => {
                        let files: Vec<String> =
                            repo.siblings.iter().map(|x| x.rfilename.clone()).collect();
                        let cache = $crate::pipeline::FileListCache {
                            files: files.clone(),
                        };
                        let json = serde_json::to_string_pretty(&cache)
                            .expect("Could not serialize cache");
                        // Best effort: the outcome of the write is logged, not propagated.
                        let ret = std::fs::write(&cache_file, json);
                        tracing::info!("Write to cache file {:?}, {:?}", cache_file, ret);
                        files
                    }
                    Err(e) => {
                        if $should_panic {
                            panic!("Could not get directory listing from API: {:?}", e)
                        } else {
                            tracing::warn!("Could not get directory listing from API: {:?}", e);
                            Vec::<String>::new()
                        }
                    }
                };
                listed.into_iter()
            }
        }
    };
}
87
/// Resolves `$file` for the model identified by `$model_id` and evaluates to its path.
///
/// * When `$model_id` exists as a local path, the file is expected directly inside it and the
///   joined path is returned.
/// * Otherwise the file is requested through `$api.get(..)`.
///
/// `$model_id` and `$file` are each evaluated exactly once, so expressions with side effects or
/// non-trivial cost are safe to pass. Both must be cheap-to-copy references (for example
/// `&Path` and `&str`), as they are used several times internally.
///
/// An `info!` logging macro must be in scope at the call site.
///
/// # Panics
///
/// Panics if the model directory exists locally but does not contain `$file`, or if the API
/// request fails.
#[doc(hidden)]
#[macro_export]
macro_rules! api_get_file {
    ($api:expr, $file:expr, $model_id:expr) => {{
        // Bind the caller's expressions once instead of re-expanding them at every use.
        let model_dir = std::path::Path::new($model_id);
        let file = $file;
        if model_dir.exists() {
            let path = model_dir.join(file);
            if !path.exists() {
                panic!("File \"{}\" not found at model id {:?}", file, model_dir)
            }
            info!("Loading `{}` locally at `{}`", &file, path.display());
            path
        } else {
            $api.get(file)
                .unwrap_or_else(|e| panic!("Could not get file {:?} from API: {:?}", file, e))
        }
    }};
}
105
/// Resolves the files of a model repository (or local model directory) and evaluates to
/// `Ok(Box::new($path_name { .. }))`.
///
/// Steps, in order:
/// 1. Build a Hub API client from the global cache, honouring `$silent`, the token derived from
///    `$token_source`, and an optional `HF_HUB_CACHE` cache directory override.
/// 2. Select the repository `$this.model_id` at `$revision` (default `"main"`).
/// 3. Resolve `tokenizer.json` (or use `$this.tokenizer_json` when set) and `config.json`.
/// 4. Resolve the weight files via `get_model_paths` and adapter files via `get_xlora_paths`.
/// 5. List the repository and resolve the optional files that are present:
///    `generation_config.json`, `preprocessor_config.json`, `processor_config.json`,
///    `chat_template.json`, and the chat template (`$this.chat_template` if set, else
///    `chat_template.jinja` if listed, else `tokenizer_config.json`).
///
/// The macro uses `?`, so it must be expanded inside a function returning a compatible
/// `Result`. `ApiBuilder`, `Repo`, `RepoType`, `PathBuf`, `get_token`, `get_model_paths`,
/// `get_xlora_paths` and `info!` must be in scope at the call site.
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filename:expr,
        $silent:expr,
        $loading_uqff:expr
    ) => {{
        let hub_api = {
            let cache = $crate::GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let builder = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            // An explicit `HF_HUB_CACHE` overrides the cache directory of the builder.
            let builder = match std::env::var("HF_HUB_CACHE") {
                Ok(x) => builder.with_cache_dir(x.into()),
                Err(_) => builder,
            };
            builder.build()?
        };
        let revision = $revision.unwrap_or("main".to_string());
        let api = hub_api.repo(Repo::with_revision(
            $this.model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&$this.model_id);

        let tokenizer_filename = match $this.tokenizer_json {
            Some(ref p) => {
                info!("Using tokenizer.json at `{p}`");
                PathBuf::from_str(p)?
            }
            None => {
                info!("Loading `tokenizer.json` at `{}`", $this.model_id);
                $crate::api_get_file!(api, "tokenizer.json", model_id)
            }
        };

        info!("Loading `config.json` at `{}`", $this.model_id);
        let config_filename = $crate::api_get_file!(api, "config.json", model_id);

        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            $quantized_model_id.as_ref(),
            $quantized_filename.as_ref(),
            &api,
            &model_id,
            $loading_uqff,
        )?;
        let adapter_paths = get_xlora_paths(
            $this.model_id.clone(),
            $this.xlora_model_id.as_ref(),
            $this.lora_adapter_ids.as_ref(),
            &$token_source,
            revision.clone(),
            $this.xlora_order.as_ref(),
        )?;

        // A failed listing yields an empty list here, so every optional file is skipped.
        let dir_list = $crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>();
        let has_file = |name: &str| dir_list.iter().any(|listed| listed.as_str() == name);

        let gen_conf = has_file("generation_config.json").then(|| {
            info!("Loading `generation_config.json` at `{}`", $this.model_id);
            $crate::api_get_file!(api, "generation_config.json", model_id)
        });
        let preprocessor_config = has_file("preprocessor_config.json").then(|| {
            info!("Loading `preprocessor_config.json` at `{}`", $this.model_id);
            $crate::api_get_file!(api, "preprocessor_config.json", model_id)
        });
        let processor_config = has_file("processor_config.json").then(|| {
            info!("Loading `processor_config.json` at `{}`", $this.model_id);
            $crate::api_get_file!(api, "processor_config.json", model_id)
        });

        // Precedence: explicit template, then `chat_template.jinja`, then
        // `tokenizer_config.json`.
        let template_filename = if let Some(ref p) = $this.chat_template {
            info!("Using chat template file at `{p}`");
            Some(PathBuf::from_str(p)?)
        } else if has_file("chat_template.jinja") {
            info!("Loading `chat_template.jinja` at `{}`", $this.model_id);
            Some($crate::api_get_file!(api, "chat_template.jinja", model_id))
        } else {
            info!("Loading `tokenizer_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(api, "tokenizer_config.json", model_id))
        };

        let chat_template_json_filename = has_file("chat_template.json").then(|| {
            info!("Loading `chat_template.json` at `{}`", $this.model_id);
            $crate::api_get_file!(api, "chat_template.json", model_id)
        });

        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename,
            filenames,
            adapter_paths,
            template_filename,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}
228
/// Resolves every file yielded by `$from_uqff` for the model `$this.model_id` and evaluates to
/// a `Vec` of the resulting paths, in iteration order.
///
/// The Hub API client is built from the global cache, honouring `$silent`, the token read from
/// `$this.token_source` (defaulting to `TokenSource::None`), and an optional `HF_HUB_CACHE`
/// cache directory override. The revision is read from `$this.revision` and defaults to
/// `"main"`. Each file is resolved with `$crate::api_get_file!`, i.e. from the local model
/// directory when it exists and through the API otherwise.
///
/// The macro uses `?`, so it must be expanded inside a function returning a compatible
/// `Result`. `ApiBuilder`, `Repo`, `RepoType`, `TokenSource`, `get_token` and `info!` must be
/// in scope at the call site.
///
/// # Panics
///
/// Panics if the token source or revision cannot be read, or if a file cannot be resolved
/// (see `api_get_file!`).
#[doc(hidden)]
#[macro_export]
macro_rules! get_uqff_paths {
    ($from_uqff:expr, $this:expr, $silent:expr) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token(
                    &$this
                        .token_source
                        .read()
                        .expect("Failed to read token source")
                        .clone()
                        .unwrap_or(TokenSource::None),
                )?);
            // An explicit `HF_HUB_CACHE` overrides the cache directory of the builder.
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $this
            .revision
            .read()
            .expect("Failed to read revision")
            .clone()
            .unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.to_string(),
            RepoType::Model,
            revision.clone(),
        ));

        // The model location is the same for every file, so build it once.
        let model_id = std::path::Path::new(&$this.model_id);

        let mut files = Vec::new();
        for file in $from_uqff {
            let file = file.display().to_string();

            // `$crate::` keeps the helper resolvable wherever this exported macro is expanded.
            files.push($crate::api_get_file!(api, &file, model_id));
        }
        files
    }};
}
272
/// Resolves the files needed to load a GGUF model and evaluates to
/// `Ok(Box::new($path_name { .. }))`.
///
/// The repository used for auxiliary files is `$this.model_id` when set and
/// `$this.quantized_model_id` otherwise. Steps, in order:
/// 1. Build a Hub API client from the global cache, honouring `$silent`, the token derived from
///    `$token_source`, and an optional `HF_HUB_CACHE` cache directory override.
/// 2. List the repository (a failed listing yields an empty list).
/// 3. Pick the chat template: `$this.chat_template` if set (must end in `.json` or `.jinja`);
///    otherwise nothing when `$this.model_id` is unset, else `chat_template.jinja` if listed,
///    else `tokenizer_config.json`.
/// 4. Resolve the GGUF files via `get_model_paths` and adapter files via `get_xlora_paths`.
/// 5. Resolve the optional `generation_config.json`, `preprocessor_config.json`,
///    `processor_config.json` and `chat_template.json` when listed, and `tokenizer.json` when
///    `$this.model_id` is set and the file is listed (an empty path otherwise).
///
/// `config_filename` is always an empty path. The macro uses `?`, so it must be expanded inside
/// a function returning a compatible `Result`. `ApiBuilder`, `Repo`, `RepoType`, `PathBuf`,
/// `get_token`, `get_model_paths`, `get_xlora_paths` and `info!` must be in scope at the call
/// site.
///
/// # Panics
///
/// Panics if `$this.chat_template` is set but does not end with `.json` or `.jinja`.
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths_gguf {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filenames:expr,
        $silent:expr
    ) => {{
        let hub_api = {
            let cache = $crate::GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let builder = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            // An explicit `HF_HUB_CACHE` overrides the cache directory of the builder.
            let builder = match std::env::var("HF_HUB_CACHE") {
                Ok(x) => builder.with_cache_dir(x.into()),
                Err(_) => builder,
            };
            builder.build()?
        };
        let revision = $revision.unwrap_or("main".to_string());
        let this_model_id = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone());
        let api = hub_api.repo(Repo::with_revision(
            this_model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&this_model_id);

        let dir_list = $crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>();
        let has_file = |name: &str| dir_list.iter().any(|listed| listed.as_str() == name);

        let chat_template = if let Some(ref p) = $this.chat_template {
            if p.ends_with(".json") || p.ends_with(".jinja") {
                info!("Using chat template file at `{p}`");
                Some(PathBuf::from_str(p)?)
            } else {
                panic!("Specified chat template file must end with .json or .jinja");
            }
        } else if $this.model_id.is_none() {
            // Without an explicit model id no template file is fetched.
            None
        } else if has_file("chat_template.jinja") {
            info!("Loading `chat_template.jinja` at `{}`", this_model_id);
            Some($crate::api_get_file!(api, "chat_template.jinja", model_id))
        } else {
            info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id);
            Some($crate::api_get_file!(api, "tokenizer_config.json", model_id))
        };

        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            Some(&$quantized_model_id),
            Some(&$quantized_filenames),
            &api,
            &model_id,
            false,
        )?;

        info!("GGUF file(s) {:?}", filenames);
        let adapter_paths = get_xlora_paths(
            this_model_id.clone(),
            $this.xlora_model_id.as_ref(),
            $this.lora_adapter_ids.as_ref(),
            &$token_source,
            revision.clone(),
            $this.xlora_order.as_ref(),
        )?;

        let gen_conf = if has_file("generation_config.json") {
            info!("Loading `generation_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(api, "generation_config.json", model_id))
        } else {
            None
        };

        let preprocessor_config = if has_file("preprocessor_config.json") {
            info!("Loading `preprocessor_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(api, "preprocessor_config.json", model_id))
        } else {
            None
        };

        let processor_config = if has_file("processor_config.json") {
            info!("Loading `processor_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(api, "processor_config.json", model_id))
        } else {
            None
        };

        // Only fetched when a model id was given explicitly; otherwise an empty path is used.
        let tokenizer_filename = if $this.model_id.is_some() && has_file("tokenizer.json") {
            info!("Loading `tokenizer.json` at `{}`", this_model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        } else {
            PathBuf::from_str("")?
        };

        let chat_template_json_filename = if has_file("chat_template.json") {
            info!("Loading `chat_template.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(api, "chat_template.json", model_id))
        } else {
            None
        };

        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename: PathBuf::from_str("")?,
            filenames,
            adapter_paths,
            template_filename: chat_template,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}
421
/// Loads a model from the weight files of `$paths` and evaluates to the value returned by
/// `$loader.load(..)`.
///
/// Layer regexes are requested from `$loader` only when both `$loading_isq` and
/// `$loading_uqff` are true (`isq_layer_regexes_moqe` when `$is_moqe`, `isq_layer_regexes`
/// otherwise) and are passed to `from_mmaped_safetensors` together with the per-tensor device
/// selector obtained from `$loader.get_device_for_tensor`.
///
/// The invocation requires a trailing comma after the last argument. The macro uses `?`, so it
/// must be expanded inside a function returning a compatible `Result`, with
/// `from_mmaped_safetensors` in scope at the call site.
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        let isq_regexes = if $loading_isq && $loading_uqff {
            let layer_regexes = if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            };
            Some(std::sync::Arc::new(layer_regexes))
        } else {
            None
        };
        let device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let weight_files = $paths.get_weight_filenames().to_vec();
        let var_builder = from_mmaped_safetensors(
            weight_files,
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            isq_regexes,
            // Predicate argument: accepts every input.
            |_| true,
            device_for_tensor,
        )?;

        $loader.load(
            &$config,
            var_builder,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}
481
/// Loads a model from an already constructed `$vb` and evaluates to the value returned by
/// `$loader.load(..)`.
///
/// The remaining arguments are forwarded: `$config` by reference, `$attention_mechanism` by
/// value, and `$mapper`, `$loading_isq`, `$real_device`, `$multi_progress` and
/// `$matformer_config` as fields of a `NormalLoadingMetadata`.
///
/// The invocation requires a trailing comma after the last argument. The macro uses `?`, so it
/// must be expanded inside a function returning a compatible `Result`.
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        // Kept as a single call so the caller's expressions are evaluated in argument order.
        $loader.load(
            &$config,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}
510
/// Loads a vision model from the weight files of `$paths` and evaluates to the value returned
/// by `$loader.load(..)`.
///
/// Layer regexes are requested from `$loader.isq_layer_regexes` only when both `$loading_isq`
/// and `$loading_uqff` are true, and are passed to `from_mmaped_safetensors` together with the
/// per-tensor device selector obtained from `$loader.get_device_for_tensor`.
///
/// The invocation requires a trailing comma after the last argument. The macro uses `?`, so it
/// must be expanded inside a function returning a compatible `Result`, with
/// `from_mmaped_safetensors` in scope at the call site.
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        let isq_regexes = if $loading_isq && $loading_uqff {
            let layer_regexes = $loader.isq_layer_regexes(&$config)?;
            Some(std::sync::Arc::new(layer_regexes))
        } else {
            None
        };
        let device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let weight_files = $paths.get_weight_filenames().to_vec();
        let var_builder = from_mmaped_safetensors(
            weight_files,
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            isq_regexes,
            // Predicate argument: accepts every input.
            |_| true,
            device_for_tensor,
        )?;

        $loader.load(
            &$config,
            var_builder,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}
565
/// Loads a vision model from an already constructed `$vb` and evaluates to the value returned
/// by `$loader.load(..)`.
///
/// The remaining arguments are forwarded: `$config` by reference, `$attention_mechanism` by
/// value, and `$mapper`, `$loading_isq`, `$real_device`, `$multi_progress` and
/// `$matformer_config` as fields of a `NormalLoadingMetadata`.
///
/// The invocation requires a trailing comma after the last argument. The macro uses `?`, so it
/// must be expanded inside a function returning a compatible `Result`.
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        // Kept as a single call so the caller's expressions are evaluated in argument order.
        $loader.load(
            &$config,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}
594
/// Loads an X-LoRA model and evaluates to the value returned by `$loader.load_xlora(..)`.
///
/// `$paths.get_adapter_paths()` must be the `AdapterPaths::XLora` variant. The base weight files
/// of `$paths` plus the classifier file are passed to `from_mmaped_safetensors` as the first
/// file list, and the adapter safetensors files as the second.
///
/// The invocation requires a trailing comma after the last argument. The macro uses `?`, so it
/// must be expanded inside a function returning a compatible `Result`, with
/// `from_mmaped_safetensors` in scope at the call site.
///
/// # Panics
///
/// Hits `unreachable!()` if the adapter paths are not the `XLora` variant, and panics on
/// `unwrap` if `classifier_path`, `adapter_safetensors`, `adapter_configs`, `xlora_config` or
/// `xlora_order` is absent.
#[doc(hidden)]
#[macro_export]
macro_rules! xlora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        let $crate::pipeline::AdapterPaths::XLora {
            adapter_configs,
            adapter_safetensors,
            classifier_path,
            xlora_order,
            xlora_config,
            lora_preload_adapter_info: _,
        } = $paths.get_adapter_paths()
        else {
            unreachable!()
        };

        // The classifier weights are loaded alongside the base model weights.
        let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
        safetensors_paths.push(classifier_path.as_ref().unwrap());
        let device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let base_weight_files = safetensors_paths
            .iter()
            .map(|path| (*path).to_owned())
            .collect::<Vec<_>>();
        let adapter_weight_files = adapter_safetensors
            .as_ref()
            .unwrap()
            .iter()
            .map(|(_, path)| (*path).to_owned())
            .collect::<Vec<_>>();

        let var_builder = from_mmaped_safetensors(
            base_weight_files,
            adapter_weight_files,
            $dtype,
            $device,
            $layer_devices,
            $silent,
            None,
            // Predicate argument: accepts every input.
            |_| true,
            device_for_tensor,
        )?;

        $loader.load_xlora(
            &$config,
            var_builder,
            adapter_configs.as_ref().unwrap(),
            Some(xlora_config.as_ref().unwrap().clone()),
            xlora_order.as_ref().unwrap().clone(),
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            &None,
        )?
    }};
}
667
/// Loads a model with LoRA adapters and evaluates to the value returned by
/// `$loader.load(..)`.
///
/// `$paths.get_adapter_paths()` must be the `AdapterPaths::Lora` variant. The base weights are
/// opened first; then, for every adapter in order, its weights are opened with
/// `from_mmaped_safetensors` and registered through `mistralrs_quant::push_applied_lora`
/// before the model itself is loaded.
///
/// Layer regexes are requested from `$loader` only when both `$loading_isq` and
/// `$loading_uqff` are true (`isq_layer_regexes_moqe` when `$is_moqe`, `isq_layer_regexes`
/// otherwise); they apply to the base weights only.
///
/// The invocation requires a trailing comma after the last argument. The macro uses `?`, so it
/// must be expanded inside a function returning a compatible `Result`, with
/// `from_mmaped_safetensors` in scope at the call site.
///
/// # Panics
///
/// Hits `unreachable!()` if the adapter paths are not the `Lora` variant.
#[doc(hidden)]
#[macro_export]
macro_rules! lora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
        $matformer_config:expr,
    ) => {{
        let $crate::pipeline::AdapterPaths::Lora(lora_adapter_paths) = $paths.get_adapter_paths()
        else {
            unreachable!()
        };

        let isq_regexes = if $loading_isq && $loading_uqff {
            let layer_regexes = if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            };
            Some(std::sync::Arc::new(layer_regexes))
        } else {
            None
        };
        let device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let weight_files = $paths.get_weight_filenames().to_vec();
        let var_builder = from_mmaped_safetensors(
            weight_files,
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            isq_regexes,
            // Predicate argument: accepts every input.
            |_| true,
            device_for_tensor.clone(),
        )?;

        // Register each adapter, in order, before the model is built.
        for $crate::pipeline::LoraAdapterPaths {
            adapter_path,
            lora_config,
        } in lora_adapter_paths
        {
            let adapter_files = vec![adapter_path.clone()];
            let lora_vb = from_mmaped_safetensors(
                adapter_files,
                Vec::new(),
                $dtype,
                $device,
                $layer_devices,
                $silent,
                None,
                |_| true,
                device_for_tensor.clone(),
            )?;

            mistralrs_quant::push_applied_lora(mistralrs_quant::LoraAdapter {
                config: lora_config.clone(),
                weights: lora_vb,
            });
        }

        $loader.load(
            &$config,
            var_builder,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
                matformer_slicing_config: $matformer_config,
            },
            $attention_mechanism,
        )?
    }};
}