1#[doc(hidden)]
2#[macro_export]
3macro_rules! api_dir_list {
4 ($api:expr, $model_id:expr) => {
5 if std::path::Path::new($model_id).exists() {
6 let listing = std::fs::read_dir($model_id);
7 if listing.is_err() {
8 panic!("Cannot list directory {:?}", $model_id)
9 }
10 let listing = listing.unwrap();
11 listing
12 .into_iter()
13 .map(|s| {
14 s.unwrap()
15 .path()
16 .file_name()
17 .unwrap() .to_str()
19 .expect("Could not convert to str")
20 .to_string()
21 })
22 .collect::<Vec<String>>()
23 .into_iter()
24 } else {
25 $api.info()
26 .map(|repo| {
27 repo.siblings
28 .iter()
29 .map(|x| x.rfilename.clone())
30 .collect::<Vec<String>>()
31 })
32 .unwrap_or_else(|e| panic!("Could not get directory listing from API: {:?}", e))
33 .into_iter()
34 }
35 };
36}
37
38#[doc(hidden)]
39#[macro_export]
40macro_rules! api_get_file {
41 ($api:expr, $file:expr, $model_id:expr) => {
42 if std::path::Path::new($model_id).exists() {
43 let path = $model_id.join($file);
44 if !path.exists() {
45 panic!("File \"{}\" not found at model id {:?}", $file, $model_id)
46 }
47 info!("Loading `{}` locally at `{}`", &$file, path.display());
48 path
49 } else {
50 $api.get($file)
51 .unwrap_or_else(|e| panic!("Could not get file {:?} from API: {:?}", $file, e))
52 }
53 };
54}
55
56#[doc(hidden)]
57#[macro_export]
58macro_rules! get_paths {
59 (
60 $path_name:ident,
61 $token_source:expr,
62 $revision:expr,
63 $this:expr,
64 $quantized_model_id:expr,
65 $quantized_filename:expr,
66 $silent:expr,
67 $loading_uqff:expr
68 ) => {{
69 let api = {
70 use $crate::GLOBAL_HF_CACHE;
71 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
72 let mut api = ApiBuilder::from_cache(cache)
73 .with_progress(!$silent)
74 .with_token(get_token($token_source)?);
75 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
76 api = api.with_cache_dir(x.into());
77 }
78 api.build()?
79 };
80 let revision = $revision.unwrap_or("main".to_string());
81 let api = api.repo(Repo::with_revision(
82 $this.model_id.clone(),
83 RepoType::Model,
84 revision.clone(),
85 ));
86 let model_id = std::path::Path::new(&$this.model_id);
87 let tokenizer_filename = if let Some(ref p) = $this.tokenizer_json {
88 info!("Using tokenizer.json at `{p}`");
89 PathBuf::from_str(p)?
90 } else {
91 info!("Loading `tokenizer.json` at `{}`", $this.model_id);
92 $crate::api_get_file!(api, "tokenizer.json", model_id)
93 };
94 info!("Loading `config.json` at `{}`", $this.model_id);
95 let config_filename = $crate::api_get_file!(api, "config.json", model_id);
96 let filenames = get_model_paths(
97 revision.clone(),
98 &$token_source,
99 &$quantized_model_id,
100 &$quantized_filename,
101 &api,
102 &model_id,
103 $loading_uqff,
104 )?;
105 let XLoraPaths {
106 adapter_configs,
107 adapter_safetensors,
108 classifier_path,
109 xlora_order,
110 xlora_config,
111 lora_preload_adapter_info,
112 } = get_xlora_paths(
113 $this.model_id.clone(),
114 &$this.xlora_model_id,
115 &$token_source,
116 revision.clone(),
117 &$this.xlora_order,
118 )?;
119 let gen_conf = if $crate::api_dir_list!(api, model_id)
120 .collect::<Vec<_>>()
121 .contains(&"generation_config.json".to_string())
122 {
123 info!("Loading `generation_config.json` at `{}`", $this.model_id);
124 Some($crate::api_get_file!(
125 api,
126 "generation_config.json",
127 model_id
128 ))
129 } else {
130 None
131 };
132 let preprocessor_config = if $crate::api_dir_list!(api, model_id)
133 .collect::<Vec<_>>()
134 .contains(&"preprocessor_config.json".to_string())
135 {
136 info!("Loading `preprocessor_config.json` at `{}`", $this.model_id);
137 Some($crate::api_get_file!(
138 api,
139 "preprocessor_config.json",
140 model_id
141 ))
142 } else {
143 None
144 };
145 let processor_config = if $crate::api_dir_list!(api, model_id)
146 .collect::<Vec<_>>()
147 .contains(&"processor_config.json".to_string())
148 {
149 info!("Loading `processor_config.json` at `{}`", $this.model_id);
150 Some($crate::api_get_file!(
151 api,
152 "processor_config.json",
153 model_id
154 ))
155 } else {
156 None
157 };
158 let template_filename = if let Some(ref p) = $this.chat_template {
159 info!("Using chat template file at `{p}`");
160 Some(PathBuf::from_str(p)?)
161 } else {
162 info!("Loading `tokenizer_config.json` at `{}`", $this.model_id);
163 Some($crate::api_get_file!(
164 api,
165 "tokenizer_config.json",
166 model_id
167 ))
168 };
169 let chat_template_json_filename = if $crate::api_dir_list!(api, model_id)
170 .collect::<Vec<_>>()
171 .contains(&"chat_template.json".to_string())
172 {
173 info!("Loading `chat_template.json` at `{}`", $this.model_id);
174 Some($crate::api_get_file!(api, "chat_template.json", model_id))
175 } else {
176 None
177 };
178 Ok(Box::new($path_name {
179 tokenizer_filename,
180 config_filename,
181 filenames,
182 xlora_adapter_configs: adapter_configs,
183 xlora_adapter_filenames: adapter_safetensors,
184 classifier_path,
185 classifier_config: xlora_config,
186 xlora_ordering: xlora_order,
187 template_filename,
188 gen_conf,
189 lora_preload_adapter_info,
190 preprocessor_config,
191 processor_config,
192 chat_template_json_filename,
193 }))
194 }};
195}
196
197#[doc(hidden)]
198#[macro_export]
199macro_rules! get_uqff_paths {
200 ($from_uqff:expr, $this:expr, $silent:expr) => {{
201 let api = {
202 use $crate::GLOBAL_HF_CACHE;
203 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
204 let mut api = ApiBuilder::from_cache(cache)
205 .with_progress(!$silent)
206 .with_token(get_token(
207 &$this
208 .token_source
209 .read()
210 .expect("Failed to read token source")
211 .clone()
212 .unwrap_or(TokenSource::None),
213 )?);
214 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
215 api = api.with_cache_dir(x.into());
216 }
217 api.build()?
218 };
219 let revision = $this
220 .revision
221 .read()
222 .expect("Failed to read revision")
223 .clone()
224 .unwrap_or("main".to_string());
225 let api = api.repo(Repo::with_revision(
226 $this.model_id.to_string(),
227 RepoType::Model,
228 revision.clone(),
229 ));
230
231 let file = $from_uqff.display().to_string();
232
233 api_get_file!(api, &file, Path::new(&$this.model_id))
234 }};
235}
236
237#[doc(hidden)]
238#[macro_export]
239macro_rules! get_paths_gguf {
240 (
241 $path_name:ident,
242 $token_source:expr,
243 $revision:expr,
244 $this:expr,
245 $quantized_model_id:expr,
246 $quantized_filenames:expr,
247 $silent:expr
248 ) => {{
249 let api = {
250 use $crate::GLOBAL_HF_CACHE;
251 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
252 let mut api = ApiBuilder::from_cache(cache)
253 .with_progress(!$silent)
254 .with_token(get_token($token_source)?);
255 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
256 api = api.with_cache_dir(x.into());
257 }
258 api.build()?
259 };
260 let revision = $revision.unwrap_or("main".to_string());
261 let this_model_id = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone());
262 let api = api.repo(Repo::with_revision(
263 this_model_id.clone(),
264 RepoType::Model,
265 revision.clone(),
266 ));
267 let model_id = std::path::Path::new(&this_model_id);
268
269 let chat_template = if let Some(ref p) = $this.chat_template {
270 if p.ends_with(".json") {
271 info!("Using chat template file at `{p}`");
272 Some(PathBuf::from_str(p)?)
273 } else {
274 panic!("Specified chat template file must end with .json");
275 }
276 } else {
277 if $this.model_id.is_none() {
278 None
279 } else {
280 info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id);
281 let res = $crate::api_get_file!(
282 api,
283 "tokenizer_config.json",
284 model_id
285 );
286 Some(res)
287 }
288 };
289
290 let filenames = get_model_paths(
291 revision.clone(),
292 &$token_source,
293 &Some($quantized_model_id),
294 &Some($quantized_filenames),
295 &api,
296 &model_id,
297 false, )?;
299
300 let XLoraPaths {
301 adapter_configs,
302 adapter_safetensors,
303 classifier_path,
304 xlora_order,
305 xlora_config,
306 lora_preload_adapter_info,
307 } = get_xlora_paths(
308 this_model_id.clone(),
309 &$this.xlora_model_id,
310 &$token_source,
311 revision.clone(),
312 &$this.xlora_order,
313 )?;
314
315 let gen_conf = if $crate::api_dir_list!(api, model_id)
316 .collect::<Vec<_>>()
317 .contains(&"generation_config.json".to_string())
318 {
319 info!("Loading `generation_config.json` at `{}`", this_model_id);
320 Some($crate::api_get_file!(
321 api,
322 "generation_config.json",
323 model_id
324 ))
325 } else {
326 None
327 };
328
329 let preprocessor_config = if $crate::api_dir_list!(api, model_id)
330 .collect::<Vec<_>>()
331 .contains(&"preprocessor_config.json".to_string())
332 {
333 info!("Loading `preprocessor_config.json` at `{}`", this_model_id);
334 Some($crate::api_get_file!(
335 api,
336 "preprocessor_config.json",
337 model_id
338 ))
339 } else {
340 None
341 };
342
343 let processor_config = if $crate::api_dir_list!(api, model_id)
344 .collect::<Vec<_>>()
345 .contains(&"processor_config.json".to_string())
346 {
347 info!("Loading `processor_config.json` at `{}`", this_model_id);
348 Some($crate::api_get_file!(
349 api,
350 "processor_config.json",
351 model_id
352 ))
353 } else {
354 None
355 };
356
357 let tokenizer_filename = if $this.model_id.is_some() {
358 info!("Loading `tokenizer.json` at `{}`", this_model_id);
359 $crate::api_get_file!(api, "tokenizer.json", model_id)
360 } else {
361 PathBuf::from_str("")?
362 };
363
364 let chat_template_json_filename = if $crate::api_dir_list!(api, model_id)
365 .collect::<Vec<_>>()
366 .contains(&"chat_template.json".to_string())
367 {
368 info!("Loading `chat_template.json` at `{}`", this_model_id);
369 Some($crate::api_get_file!(
370 api,
371 "chat_template.json",
372 model_id
373 ))
374 } else {
375 None
376 };
377
378 Ok(Box::new($path_name {
379 tokenizer_filename,
380 config_filename: PathBuf::from_str("")?,
381 filenames,
382 xlora_adapter_configs: adapter_configs,
383 xlora_adapter_filenames: adapter_safetensors,
384 classifier_path,
385 classifier_config: xlora_config,
386 xlora_ordering: xlora_order,
387 template_filename: chat_template,
388 gen_conf,
389 lora_preload_adapter_info,
390 preprocessor_config,
391 processor_config,
392 chat_template_json_filename,
393 }))
394 }};
395}
396
397#[doc(hidden)]
398#[macro_export]
399macro_rules! normal_model_loader {
400 (
401 $paths:expr,
402 $dtype:expr,
403 $device:expr,
404 $layer_devices:expr,
405 $config:expr,
406 $loader:expr,
407 $use_flash_attn:expr,
408 $silent:expr,
409 $mapper:expr,
410 $loading_isq:expr,
411 $loading_uqff:expr,
412 $real_device:expr,
413 $attention_mechanism:expr,
414 $is_moqe:expr,
415 $multi_progress:expr,
416 ) => {{
417 let regexes = if $loading_isq && $loading_uqff {
418 Some(std::sync::Arc::new(if $is_moqe {
420 $loader.isq_layer_regexes_moqe(&$config)?
421 } else {
422 $loader.isq_layer_regexes(&$config)?
423 }))
424 } else {
425 None
426 };
427 let get_device_for_tensor =
428 $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;
429
430 let vb = from_mmaped_safetensors(
431 $paths.get_weight_filenames().to_vec(),
432 Vec::new(),
433 $dtype,
434 $device,
435 $layer_devices,
436 $silent,
437 regexes,
438 |_| true, get_device_for_tensor,
440 )?;
441
442 $loader.load(
443 &$config,
444 $use_flash_attn,
445 vb,
446 $crate::pipeline::NormalLoadingMetadata {
447 mapper: $mapper,
448 loading_isq: $loading_isq,
449 real_device: $real_device,
450 multi_progress: $multi_progress,
451 },
452 $attention_mechanism,
453 )?
454 }};
455}
456
457#[doc(hidden)]
458#[macro_export]
459macro_rules! normal_model_loader_sharded {
460 (
461 $vb:expr,
462 $config:expr,
463 $loader:expr,
464 $use_flash_attn:expr,
465 $mapper:expr,
466 $loading_isq:expr,
467 $real_device:expr,
468 $attention_mechanism:expr,
469 $multi_progress:expr,
470 ) => {{
471 $loader.load(
472 &$config,
473 $use_flash_attn,
474 $vb,
475 $crate::pipeline::NormalLoadingMetadata {
476 mapper: $mapper,
477 loading_isq: $loading_isq,
478 real_device: $real_device,
479 multi_progress: $multi_progress,
480 },
481 $attention_mechanism,
482 )?
483 }};
484}
485
486#[doc(hidden)]
487#[macro_export]
488macro_rules! vision_normal_model_loader {
489 (
490 $paths:expr,
491 $dtype:expr,
492 $device:expr,
493 $layer_devices:expr,
494 $config:expr,
495 $loader:expr,
496 $use_flash_attn:expr,
497 $silent:expr,
498 $mapper:expr,
499 $loading_isq:expr,
500 $loading_uqff:expr,
501 $real_device:expr,
502 $attention_mechanism:expr,
503 $multi_progress:expr,
504 ) => {{
505 let regexes = if $loading_isq && $loading_uqff {
506 Some(std::sync::Arc::new($loader.isq_layer_regexes(&$config)?))
508 } else {
509 None
510 };
511 let get_device_for_tensor =
512 $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;
513
514 let vb = from_mmaped_safetensors(
515 $paths.get_weight_filenames().to_vec(),
516 Vec::new(),
517 $dtype,
518 $device,
519 $layer_devices,
520 $silent,
521 regexes,
522 |_| true, get_device_for_tensor,
524 )?;
525
526 $loader.load(
527 &$config,
528 $use_flash_attn,
529 vb,
530 $crate::pipeline::NormalLoadingMetadata {
531 mapper: $mapper,
532 loading_isq: $loading_isq,
533 real_device: $real_device,
534 multi_progress: $multi_progress,
535 },
536 $attention_mechanism,
537 )?
538 }};
539}
540
541#[doc(hidden)]
542#[macro_export]
543macro_rules! vision_normal_model_loader_sharded {
544 (
545 $vb:expr,
546 $config:expr,
547 $loader:expr,
548 $use_flash_attn:expr,
549 $mapper:expr,
550 $loading_isq:expr,
551 $real_device:expr,
552 $attention_mechanism:expr,
553 $multi_progress:expr,
554 ) => {{
555 $loader.load(
556 &$config,
557 $use_flash_attn,
558 $vb,
559 $crate::pipeline::NormalLoadingMetadata {
560 mapper: $mapper,
561 loading_isq: $loading_isq,
562 real_device: $real_device,
563 multi_progress: $multi_progress,
564 },
565 $attention_mechanism,
566 )?
567 }};
568}
569
570#[doc(hidden)]
571#[macro_export]
572macro_rules! xlora_model_loader {
573 (
574 $paths:expr,
575 $dtype:expr,
576 $device:expr,
577 $layer_devices:expr,
578 $config:expr,
579 $loader:expr,
580 $use_flash_attn:expr,
581 $silent:expr,
582 $mapper:expr,
583 $loading_isq:expr,
584 $real_device:expr,
585 $multi_progress:expr,
586 ) => {{
587 let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
588 safetensors_paths.push($paths.get_classifier_path().as_ref().unwrap());
589 let get_device_for_tensor =
590 $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;
591
592 let vb = from_mmaped_safetensors(
593 safetensors_paths
594 .iter()
595 .map(|x| (*x).to_owned())
596 .collect::<Vec<_>>(),
597 $paths
598 .get_adapter_filenames()
599 .as_ref()
600 .unwrap()
601 .iter()
602 .map(|(_, x)| (*x).to_owned())
603 .collect::<Vec<_>>(),
604 $dtype,
605 $device,
606 $layer_devices,
607 $silent,
608 None,
609 |_| true,
610 get_device_for_tensor,
611 )?;
612
613 $loader.load_xlora(
614 &$config,
615 $use_flash_attn,
616 vb,
617 $paths.get_adapter_configs().as_ref().unwrap(),
618 Some($paths.get_classifier_config().as_ref().unwrap().clone()),
619 $paths.get_ordering().as_ref().unwrap().clone(),
620 $crate::pipeline::NormalLoadingMetadata {
621 mapper: $mapper,
622 loading_isq: $loading_isq,
623 real_device: $real_device,
624 multi_progress: $multi_progress,
625 },
626 &None,
627 )?
628 }};
629}
630
631#[doc(hidden)]
632#[macro_export]
633macro_rules! lora_model_loader {
634 (
635 $paths:expr,
636 $dtype:expr,
637 $device:expr,
638 $layer_devices:expr,
639 $config:expr,
640 $loader:expr,
641 $use_flash_attn:expr,
642 $silent:expr,
643 $mapper:expr,
644 $loading_isq:expr,
645 $real_device:expr,
646 $multi_progress:expr,
647 ) => {{
648 let safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
649 let get_device_for_tensor =
650 $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;
651
652 let vb = from_mmaped_safetensors(
653 safetensors_paths
654 .iter()
655 .map(|x| (*x).to_owned())
656 .collect::<Vec<_>>(),
657 $paths
658 .get_adapter_filenames()
659 .as_ref()
660 .unwrap()
661 .iter()
662 .map(|(_, x)| (*x).to_owned())
663 .collect::<Vec<_>>(),
664 Some($dtype),
665 $device,
666 $layer_devices,
667 $silent,
668 None,
669 |_| true,
670 get_device_for_tensor,
671 )?;
672
673 $loader.load_xlora(
674 &$config,
675 $use_flash_attn,
676 vb,
677 $paths.get_adapter_configs().as_ref().unwrap(),
678 None,
679 $paths.get_ordering().as_ref().unwrap().clone(),
680 $crate::pipeline::NormalLoadingMetadata {
681 mapper: $mapper,
682 loading_isq: $loading_isq,
683 real_device: $real_device,
684 multi_progress: $multi_progress,
685 },
686 &$crate::utils::varbuilder_utils::load_preload_adapters(
687 $paths.get_lora_preload_adapter_info(),
688 $dtype,
689 $device,
690 $silent,
691 )?,
692 )?
693 }};
694}