#[doc(hidden)]
#[macro_export]
macro_rules! api_dir_list {
    ($api:expr, $model_id:expr) => {
        if std::path::Path::new($model_id).exists() {
            let listing = std::fs::read_dir($model_id);
            if listing.is_err() {
                panic!("Cannot list directory {:?}", $model_id)
            }
            let listing = listing.unwrap();
            listing
                .into_iter()
                .map(|s| {
                    s.unwrap()
                        .path()
                        .file_name()
                        .unwrap()
                        .to_str()
                        .expect("Could not convert to str")
                        .to_string()
                })
                .collect::<Vec<String>>()
                .into_iter()
        } else {
            $api.info()
                .map(|repo| {
                    repo.siblings
                        .iter()
                        .map(|x| x.rfilename.clone())
                        .collect::<Vec<String>>()
                })
                .unwrap_or_else(|e| panic!("Could not get directory listing from API: {:?}", e))
                .into_iter()
        }
    };
}

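// Like `api_dir_list!` above, this macro prefers a local checkout: if `$model_id` exists on
// disk it joins `$file` onto that directory, otherwise it fetches the file through `$api`.
// A minimal usage sketch (assuming `api` is a repo handle built as in `get_paths!` below and
// `./my-model-dir` is a hypothetical local directory):
//
//     let model_id = std::path::Path::new("./my-model-dir");
//     let has_config = $crate::api_dir_list!(api, model_id).any(|f| f == "config.json");
//     let config_path = $crate::api_get_file!(api, "config.json", model_id);
//
// Both macros panic on failure instead of returning a `Result`.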
#[doc(hidden)]
#[macro_export]
macro_rules! api_get_file {
    ($api:expr, $file:expr, $model_id:expr) => {
        if std::path::Path::new($model_id).exists() {
            let path = $model_id.join($file);
            if !path.exists() {
                panic!("File \"{}\" not found at model id {:?}", $file, $model_id)
            }
            info!("Loading `{}` locally at `{}`", &$file, path.display());
            path
        } else {
            $api.get($file)
                .unwrap_or_else(|e| panic!("Could not get file {:?} from API: {:?}", $file, e))
        }
    };
}

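// Resolves every artifact a plain safetensors pipeline needs: `tokenizer.json`, `config.json`,
// the weight filenames, (X-)LoRA adapter paths, and the optional generation / preprocessor /
// processor / chat-template JSON files, preferring local files via `api_get_file!`. It expects
// `ApiBuilder`, `Repo`, `RepoType`, `get_token`, `get_model_paths`, `get_xlora_paths`,
// `PathBuf`, `FromStr`, and `info!` to be in scope at the call site, and it evaluates to
// `Ok(Box::new($path_name { .. }))`, so it must expand inside a function returning a
// compatible `Result`.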
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filename:expr,
        $silent:expr,
        $loading_uqff:expr
    ) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $revision.unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&$this.model_id);
        let tokenizer_filename = if let Some(ref p) = $this.tokenizer_json {
            info!("Using tokenizer.json at `{p}`");
            PathBuf::from_str(p)?
        } else {
            info!("Loading `tokenizer.json` at `{}`", $this.model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        };
        info!("Loading `config.json` at `{}`", $this.model_id);
        let config_filename = $crate::api_get_file!(api, "config.json", model_id);
        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            &$quantized_model_id,
            &$quantized_filename,
            &api,
            &model_id,
            $loading_uqff,
        )?;
        let adapter_paths = get_xlora_paths(
            $this.model_id.clone(),
            &$this.xlora_model_id,
            &$this.lora_adapter_ids,
            &$token_source,
            revision.clone(),
            &$this.xlora_order,
        )?;
        let gen_conf = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"generation_config.json".to_string())
        {
            info!("Loading `generation_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "generation_config.json",
                model_id
            ))
        } else {
            None
        };
        let preprocessor_config = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"preprocessor_config.json".to_string())
        {
            info!("Loading `preprocessor_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "preprocessor_config.json",
                model_id
            ))
        } else {
            None
        };
        let processor_config = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"processor_config.json".to_string())
        {
            info!("Loading `processor_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "processor_config.json",
                model_id
            ))
        } else {
            None
        };
        let template_filename = if let Some(ref p) = $this.chat_template {
            info!("Using chat template file at `{p}`");
            Some(PathBuf::from_str(p)?)
        } else {
            info!("Loading `tokenizer_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "tokenizer_config.json",
                model_id
            ))
        };
        let chat_template_json_filename = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"chat_template.json".to_string())
        {
            info!("Loading `chat_template.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(api, "chat_template.json", model_id))
        } else {
            None
        };
        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename,
            filenames,
            adapter_paths,
            template_filename,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}

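// Downloads the UQFF container files listed in `$from_uqff` from the model repository, using
// the token source and revision stored on `$this` (read through their locks) and the same
// `HF_HUB_CACHE` override as above. Evaluates to a `Vec` of local `PathBuf`s.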
#[doc(hidden)]
#[macro_export]
macro_rules! get_uqff_paths {
    ($from_uqff:expr, $this:expr, $silent:expr) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token(
                    &$this
                        .token_source
                        .read()
                        .expect("Failed to read token source")
                        .clone()
                        .unwrap_or(TokenSource::None),
                )?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $this
            .revision
            .read()
            .expect("Failed to read revision")
            .clone()
            .unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(
            $this.model_id.to_string(),
            RepoType::Model,
            revision.clone(),
        ));

        let mut files = Vec::new();
        for file in $from_uqff {
            let file = file.display().to_string();

            files.push(api_get_file!(api, &file, Path::new(&$this.model_id)));
        }
        files
    }};
}

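// GGUF variant of `get_paths!`: the quantized repo id can stand in for a missing `model_id`,
// `config_filename` is left as an empty path (the GGUF file supplies the metadata instead),
// and `tokenizer.json` is only fetched when a separate `model_id` was given.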
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths_gguf {
    (
        $path_name:ident,
        $token_source:expr,
        $revision:expr,
        $this:expr,
        $quantized_model_id:expr,
        $quantized_filenames:expr,
        $silent:expr
    ) => {{
        let api = {
            use $crate::GLOBAL_HF_CACHE;
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let mut api = ApiBuilder::from_cache(cache)
                .with_progress(!$silent)
                .with_token(get_token($token_source)?);
            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                api = api.with_cache_dir(x.into());
            }
            api.build()?
        };
        let revision = $revision.unwrap_or("main".to_string());
        let this_model_id = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone());
        let api = api.repo(Repo::with_revision(
            this_model_id.clone(),
            RepoType::Model,
            revision.clone(),
        ));
        let model_id = std::path::Path::new(&this_model_id);

        let chat_template = if let Some(ref p) = $this.chat_template {
            if p.ends_with(".json") {
                info!("Using chat template file at `{p}`");
                Some(PathBuf::from_str(p)?)
            } else {
                panic!("Specified chat template file must end with .json");
            }
        } else if $this.model_id.is_none() {
            None
        } else {
            info!(
                "Loading `tokenizer_config.json` at `{}` because no chat template file was specified.",
                this_model_id
            );
            let res = $crate::api_get_file!(api, "tokenizer_config.json", model_id);
            Some(res)
        };

        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
            &Some($quantized_model_id),
            &Some($quantized_filenames),
            &api,
            &model_id,
            false,
        )?;

        let adapter_paths = get_xlora_paths(
            this_model_id.clone(),
            &$this.xlora_model_id,
            &$this.lora_adapter_ids,
            &$token_source,
            revision.clone(),
            &$this.xlora_order,
        )?;

        let gen_conf = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"generation_config.json".to_string())
        {
            info!("Loading `generation_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "generation_config.json",
                model_id
            ))
        } else {
            None
        };

        let preprocessor_config = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"preprocessor_config.json".to_string())
        {
            info!("Loading `preprocessor_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "preprocessor_config.json",
                model_id
            ))
        } else {
            None
        };

        let processor_config = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"processor_config.json".to_string())
        {
            info!("Loading `processor_config.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "processor_config.json",
                model_id
            ))
        } else {
            None
        };

        let tokenizer_filename = if $this.model_id.is_some() {
            info!("Loading `tokenizer.json` at `{}`", this_model_id);
            $crate::api_get_file!(api, "tokenizer.json", model_id)
        } else {
            PathBuf::from_str("")?
        };

        let chat_template_json_filename = if $crate::api_dir_list!(api, model_id)
            .collect::<Vec<_>>()
            .contains(&"chat_template.json".to_string())
        {
            info!("Loading `chat_template.json` at `{}`", this_model_id);
            Some($crate::api_get_file!(
                api,
                "chat_template.json",
                model_id
            ))
        } else {
            None
        };

        Ok(Box::new($path_name {
            tokenizer_filename,
            config_filename: PathBuf::from_str("")?,
            filenames,
            adapter_paths,
            template_filename: chat_template,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }))
    }};
}

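// Builds a `VarBuilder` from the memory-mapped safetensors listed in `$paths` and hands it to
// `$loader.load(..)` together with the `NormalLoadingMetadata`. When both ISQ and UQFF loading
// are requested, the loader's per-layer ISQ regexes (MoQE or regular, depending on `$is_moqe`)
// are forwarded to `from_mmaped_safetensors` to control how those layers are materialized.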
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
    ) => {{
        let regexes = if $loading_isq && $loading_uqff {
            Some(std::sync::Arc::new(if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            }))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true,
            get_device_for_tensor,
        )?;

        $loader.load(
            &$config,
            $use_flash_attn,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}

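// Sharded counterpart of `normal_model_loader!`: the caller supplies an already-built `$vb`,
// so this macro only forwards it to `$loader.load(..)` with the usual loading metadata.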
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        $loader.load(
            &$config,
            $use_flash_attn,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}

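// Same flow as `normal_model_loader!`, but for vision models: it always uses
// `isq_layer_regexes`, since there is no MoQE branch here.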
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        let regexes = if $loading_isq && $loading_uqff {
            Some(std::sync::Arc::new($loader.isq_layer_regexes(&$config)?))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true,
            get_device_for_tensor,
        )?;

        $loader.load(
            &$config,
            $use_flash_attn,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}

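// Sharded counterpart of `vision_normal_model_loader!`; structurally identical to
// `normal_model_loader_sharded!`.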
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader_sharded {
    (
        $vb:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $multi_progress:expr,
    ) => {{
        $loader.load(
            &$config,
            $use_flash_attn,
            $vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}

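// Loads a base model together with its X-LoRA classifier and adapters: the classifier
// safetensors file is appended to the base weight paths, the adapter safetensors are passed
// as the second file list to `from_mmaped_safetensors`, and everything goes through
// `$loader.load_xlora(..)`. Hits `unreachable!()` if `$paths` does not carry
// `AdapterPaths::XLora`.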
#[doc(hidden)]
#[macro_export]
macro_rules! xlora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $real_device:expr,
        $multi_progress:expr,
    ) => {{
        let $crate::pipeline::AdapterPaths::XLora {
            adapter_configs,
            adapter_safetensors,
            classifier_path,
            xlora_order,
            xlora_config,
            lora_preload_adapter_info: _,
        } = $paths.get_adapter_paths()
        else {
            unreachable!()
        };

        let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
        safetensors_paths.push(classifier_path.as_ref().unwrap());
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            safetensors_paths
                .iter()
                .map(|x| (*x).to_owned())
                .collect::<Vec<_>>(),
            adapter_safetensors
                .as_ref()
                .unwrap()
                .iter()
                .map(|(_, x)| (*x).to_owned())
                .collect::<Vec<_>>(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            None,
            |_| true,
            get_device_for_tensor,
        )?;

        $loader.load_xlora(
            &$config,
            $use_flash_attn,
            vb,
            adapter_configs.as_ref().unwrap(),
            Some(xlora_config.as_ref().unwrap().clone()),
            xlora_order.as_ref().unwrap().clone(),
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            &None,
        )?
    }};
}

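// Loads a base model and registers each LoRA adapter: every adapter safetensors file gets its
// own `VarBuilder`, which is pushed into the global `mistralrs_quant::APPLIED_LORAS` list
// before the base model itself is built with `$loader.load(..)`. Hits `unreachable!()` if
// `$paths` does not carry `AdapterPaths::Lora`.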
#[doc(hidden)]
#[macro_export]
macro_rules! lora_model_loader {
    (
        $paths:expr,
        $dtype:expr,
        $device:expr,
        $layer_devices:expr,
        $config:expr,
        $loader:expr,
        $use_flash_attn:expr,
        $silent:expr,
        $mapper:expr,
        $loading_isq:expr,
        $loading_uqff:expr,
        $real_device:expr,
        $attention_mechanism:expr,
        $is_moqe:expr,
        $multi_progress:expr,
    ) => {{
        let $crate::pipeline::AdapterPaths::Lora(lora_adapter_paths) = $paths.get_adapter_paths()
        else {
            unreachable!()
        };

        let regexes = if $loading_isq && $loading_uqff {
            Some(std::sync::Arc::new(if $is_moqe {
                $loader.isq_layer_regexes_moqe(&$config)?
            } else {
                $loader.isq_layer_regexes(&$config)?
            }))
        } else {
            None
        };
        let get_device_for_tensor =
            $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;

        let vb = from_mmaped_safetensors(
            $paths.get_weight_filenames().to_vec(),
            Vec::new(),
            $dtype,
            $device,
            $layer_devices,
            $silent,
            regexes,
            |_| true,
            get_device_for_tensor.clone(),
        )?;

        for $crate::pipeline::LoraAdapterPaths {
            adapter_path,
            lora_config,
        } in lora_adapter_paths
        {
            let lora_vb = from_mmaped_safetensors(
                vec![adapter_path.clone()],
                Vec::new(),
                $dtype,
                $device,
                $layer_devices,
                $silent,
                None,
                |_| true,
                get_device_for_tensor.clone(),
            )?;

            mistralrs_quant::APPLIED_LORAS
                .lock()
                .unwrap()
                .push(mistralrs_quant::LoraAdapter {
                    config: lora_config.clone(),
                    weights: lora_vb,
                });
        }

        $loader.load(
            &$config,
            $use_flash_attn,
            vb,
            $crate::pipeline::NormalLoadingMetadata {
                mapper: $mapper,
                loading_isq: $loading_isq,
                real_device: $real_device,
                multi_progress: $multi_progress,
            },
            $attention_mechanism,
        )?
    }};
}