1#[doc(hidden)]
2#[macro_export]
3macro_rules! api_dir_list {
4 ($api:expr, $model_id:expr, $should_panic:expr) => {
5 if std::path::Path::new($model_id).exists() {
6 let listing = std::fs::read_dir($model_id);
7 if listing.is_err() {
8 panic!("Cannot list directory {:?}", $model_id)
9 }
10 let listing = listing.unwrap();
11 listing
12 .into_iter()
13 .map(|s| {
14 s.unwrap()
15 .path()
16 .file_name()
17 .unwrap() .to_str()
19 .expect("Could not convert to str")
20 .to_string()
21 })
22 .collect::<Vec<String>>()
23 .into_iter()
24 } else {
25 let sanitized_id = std::path::Path::new($model_id)
26 .display()
27 .to_string()
28 .replace("/", "-");
29
30 let home_folder = if dirs::home_dir().is_some() {
31 let mut path = dirs::home_dir().unwrap();
32 path.push(".cache/huggingface/hub/");
33 if !path.exists() {
34 let _ = std::fs::create_dir_all(&path);
35 }
36 path
37 } else {
38 "./".into()
39 };
40
41 let cache_dir: std::path::PathBuf = std::env::var("HF_HUB_CACHE")
42 .map(std::path::PathBuf::from)
43 .unwrap_or(home_folder.into());
44 let cache_file = cache_dir.join(format!("{sanitized_id}_repo_list.json"));
45 if std::path::Path::new(&cache_file).exists() {
46 use std::io::Read;
47 let mut file = std::fs::File::open(&cache_file).expect("Could not open cache file");
49 let mut contents = String::new();
50 file.read_to_string(&mut contents)
51 .expect("Could not read cache file");
52 let cache: $crate::pipeline::FileListCache =
53 serde_json::from_str(&contents).expect("Could not parse cache JSON");
54 tracing::info!("Read from cache file {:?}", cache_file);
55 cache.files.into_iter()
56 } else {
57 $api.info()
58 .map(|repo| {
59 let files: Vec<String> = repo
60 .siblings
61 .iter()
62 .map(|x| x.rfilename.clone())
63 .collect::<Vec<String>>();
64 let cache = $crate::pipeline::FileListCache {
66 files: files.clone(),
67 };
68 let json = serde_json::to_string_pretty(&cache)
69 .expect("Could not serialize cache");
70 let ret = std::fs::write(&cache_file, json);
71 tracing::info!("Write to cache file {:?}, {:?}", cache_file, ret);
72 files
73 })
74 .unwrap_or_else(|e| {
75 if $should_panic {
76 panic!("Could not get directory listing from API: {:?}", e)
77 } else {
78 tracing::warn!("Could not get directory listing from API: {:?}", e);
79 Vec::<String>::new()
80 }
81 })
82 .into_iter()
83 }
84 }
85 };
86}
87
88#[doc(hidden)]
89#[macro_export]
90macro_rules! api_get_file {
91 ($api:expr, $file:expr, $model_id:expr) => {
92 if std::path::Path::new($model_id).exists() {
93 let path = $model_id.join($file);
94 if !path.exists() {
95 panic!("File \"{}\" not found at model id {:?}", $file, $model_id)
96 }
97 info!("Loading `{}` locally at `{}`", &$file, path.display());
98 path
99 } else {
100 $api.get($file)
101 .unwrap_or_else(|e| panic!("Could not get file {:?} from API: {:?}", $file, e))
102 }
103 };
104}
105
106#[doc(hidden)]
107#[macro_export]
108macro_rules! get_paths {
109 (
110 $path_name:ident,
111 $token_source:expr,
112 $revision:expr,
113 $this:expr,
114 $quantized_model_id:expr,
115 $quantized_filename:expr,
116 $silent:expr,
117 $loading_uqff:expr
118 ) => {{
119 let api = {
120 use $crate::GLOBAL_HF_CACHE;
121 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
122 let mut api = ApiBuilder::from_cache(cache)
123 .with_progress(!$silent)
124 .with_token(get_token($token_source)?);
125 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
126 api = api.with_cache_dir(x.into());
127 }
128 api.build()?
129 };
130 let revision = $revision.unwrap_or("main".to_string());
131 let api = api.repo(Repo::with_revision(
132 $this.model_id.clone(),
133 RepoType::Model,
134 revision.clone(),
135 ));
136 let model_id = std::path::Path::new(&$this.model_id);
137 let tokenizer_filename = if let Some(ref p) = $this.tokenizer_json {
138 info!("Using tokenizer.json at `{p}`");
139 PathBuf::from_str(p)?
140 } else {
141 info!("Loading `tokenizer.json` at `{}`", $this.model_id);
142 $crate::api_get_file!(api, "tokenizer.json", model_id)
143 };
144 info!("Loading `config.json` at `{}`", $this.model_id);
145 let config_filename = $crate::api_get_file!(api, "config.json", model_id);
146 let filenames = get_model_paths(
147 revision.clone(),
148 &$token_source,
149 $quantized_model_id.as_ref(),
150 $quantized_filename.as_ref(),
151 &api,
152 &model_id,
153 $loading_uqff,
154 )?;
155 let adapter_paths = get_xlora_paths(
156 $this.model_id.clone(),
157 $this.xlora_model_id.as_ref(),
158 $this.lora_adapter_ids.as_ref(),
159 &$token_source,
160 revision.clone(),
161 $this.xlora_order.as_ref(),
162 )?;
163 let dir_list = $crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>();
164
165 let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) {
166 info!("Loading `generation_config.json` at `{}`", $this.model_id);
167 Some($crate::api_get_file!(
168 api,
169 "generation_config.json",
170 model_id
171 ))
172 } else {
173 None
174 };
175 let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string()) {
176 info!("Loading `preprocessor_config.json` at `{}`", $this.model_id);
177 Some($crate::api_get_file!(
178 api,
179 "preprocessor_config.json",
180 model_id
181 ))
182 } else {
183 None
184 };
185 let processor_config = if dir_list.contains(&"processor_config.json".to_string()) {
186 info!("Loading `processor_config.json` at `{}`", $this.model_id);
187 Some($crate::api_get_file!(
188 api,
189 "processor_config.json",
190 model_id
191 ))
192 } else {
193 None
194 };
195 let template_filename = if let Some(ref p) = $this.chat_template {
196 info!("Using chat template file at `{p}`");
197 Some(PathBuf::from_str(p)?)
198 } else if dir_list.contains(&"chat_template.jinja".to_string()) {
199 info!("Loading `chat_template.jinja` at `{}`", $this.model_id);
200 Some($crate::api_get_file!(api, "chat_template.jinja", model_id))
201 } else {
202 info!("Loading `tokenizer_config.json` at `{}`", $this.model_id);
203 Some($crate::api_get_file!(
204 api,
205 "tokenizer_config.json",
206 model_id
207 ))
208 };
209 let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) {
210 info!("Loading `chat_template.json` at `{}`", $this.model_id);
211 Some($crate::api_get_file!(api, "chat_template.json", model_id))
212 } else {
213 None
214 };
215 Ok(Box::new($path_name {
216 tokenizer_filename,
217 config_filename,
218 filenames,
219 adapter_paths,
220 template_filename,
221 gen_conf,
222 preprocessor_config,
223 processor_config,
224 chat_template_json_filename,
225 }))
226 }};
227}
228
229#[doc(hidden)]
230#[macro_export]
231macro_rules! get_embedding_paths {
232 (
233 $path_name:ident,
234 $token_source:expr,
235 $revision:expr,
236 $this:expr,
237 $quantized_model_id:expr,
238 $quantized_filename:expr,
239 $silent:expr,
240 $loading_uqff:expr
241 ) => {{
242 let api = {
243 use $crate::GLOBAL_HF_CACHE;
244 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
245 let mut api = ApiBuilder::from_cache(cache)
246 .with_progress(!$silent)
247 .with_token(get_token($token_source)?);
248 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
249 api = api.with_cache_dir(x.into());
250 }
251 api.build()?
252 };
253 let revision = $revision.unwrap_or("main".to_string());
254 let api = api.repo(Repo::with_revision(
255 $this.model_id.clone(),
256 RepoType::Model,
257 revision.clone(),
258 ));
259 let model_id = std::path::Path::new(&$this.model_id);
260 let tokenizer_filename = if let Some(ref p) = $this.tokenizer_json {
261 info!("Using tokenizer.json at `{p}`");
262 PathBuf::from_str(p)?
263 } else {
264 info!("Loading `tokenizer.json` at `{}`", $this.model_id);
265 $crate::api_get_file!(api, "tokenizer.json", model_id)
266 };
267 info!("Loading `config.json` at `{}`", $this.model_id);
268 let config_filename = $crate::api_get_file!(api, "config.json", model_id);
269 let filenames = get_model_paths(
270 revision.clone(),
271 &$token_source,
272 $quantized_model_id.as_ref(),
273 $quantized_filename.as_ref(),
274 &api,
275 &model_id,
276 $loading_uqff,
277 )?;
278 let adapter_paths = get_xlora_paths(
279 $this.model_id.clone(),
280 None, $this.lora_adapter_ids.as_ref(),
282 &$token_source,
283 revision.clone(),
284 None, )?;
286
287 let mut parsed_modules = Vec::new();
288 let is_local = std::path::Path::new(&$this.model_id).exists();
289 let modules_path = if is_local {
290 model_id.join("modules.json")
291 } else {
292 $crate::api_get_file!(api, "modules.json", model_id)
293 };
294
295 if modules_path.exists() {
296 let modules: Vec<$crate::pipeline::EmbeddingModule> =
297 serde_json::from_str(&std::fs::read_to_string(&modules_path)?)?;
298 for module in modules {
299 match module.ty {
300 $crate::pipeline::EmbeddingModuleType::Transformer => {
301 parsed_modules.push($crate::pipeline::EmbeddingModulePaths::Transformer {
302 path: module.path.clone(),
303 });
304 }
305 $crate::pipeline::EmbeddingModuleType::Pooling => {
306 parsed_modules.push($crate::pipeline::EmbeddingModulePaths::Pooling {
307 path: module.path.clone(),
308 config: $crate::api_get_file!(
309 api,
310 &format!("{}/config.json", module.path),
311 model_id
312 ),
313 });
314 }
315 $crate::pipeline::EmbeddingModuleType::Dense => {
316 parsed_modules.push($crate::pipeline::EmbeddingModulePaths::Dense {
317 path: module.path.clone(),
318 config: $crate::api_get_file!(
319 api,
320 &format!("{}/config.json", module.path),
321 model_id
322 ),
323 model: $crate::api_get_file!(
324 api,
325 &format!("{}/model.safetensors", module.path),
326 model_id
327 ),
328 });
329 }
330 $crate::pipeline::EmbeddingModuleType::Normalize => {
331 parsed_modules.push($crate::pipeline::EmbeddingModulePaths::Normalize {
332 path: module.path.clone(),
333 });
334 }
335 }
336 }
337 }
338
339 Ok(Box::new($path_name {
340 tokenizer_filename,
341 config_filename,
342 filenames,
343 adapter_paths,
344 modules: parsed_modules,
345 }))
346 }};
347}
348
349#[doc(hidden)]
350#[macro_export]
351macro_rules! get_uqff_paths {
352 ($from_uqff:expr, $this:expr, $silent:expr) => {{
353 let api = {
354 use $crate::GLOBAL_HF_CACHE;
355 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
356 let mut api = ApiBuilder::from_cache(cache)
357 .with_progress(!$silent)
358 .with_token(get_token(
359 &$this
360 .token_source
361 .read()
362 .expect("Failed to read token source")
363 .clone()
364 .unwrap_or(TokenSource::None),
365 )?);
366 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
367 api = api.with_cache_dir(x.into());
368 }
369 api.build()?
370 };
371 let revision = $this
372 .revision
373 .read()
374 .expect("Failed to read revision")
375 .clone()
376 .unwrap_or("main".to_string());
377 let api = api.repo(Repo::with_revision(
378 $this.model_id.to_string(),
379 RepoType::Model,
380 revision.clone(),
381 ));
382
383 let mut files = Vec::new();
384 for file in $from_uqff {
385 let file = file.display().to_string();
386
387 files.push(api_get_file!(api, &file, Path::new(&$this.model_id)));
388 }
389 files
390 }};
391}
392
393#[doc(hidden)]
394#[macro_export]
395macro_rules! get_paths_gguf {
396 (
397 $path_name:ident,
398 $token_source:expr,
399 $revision:expr,
400 $this:expr,
401 $quantized_model_id:expr,
402 $quantized_filenames:expr,
403 $silent:expr
404 ) => {{
405 let api = {
406 use $crate::GLOBAL_HF_CACHE;
407 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
408 let mut api = ApiBuilder::from_cache(cache)
409 .with_progress(!$silent)
410 .with_token(get_token($token_source)?);
411 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
412 api = api.with_cache_dir(x.into());
413 }
414 api.build()?
415 };
416 let revision = $revision.unwrap_or("main".to_string());
417 let this_model_id = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone());
418 let api = api.repo(Repo::with_revision(
419 this_model_id.clone(),
420 RepoType::Model,
421 revision.clone(),
422 ));
423 let model_id = std::path::Path::new(&this_model_id);
424
425 let dir_list = $crate::api_dir_list!(api, model_id, false)
426 .collect::<Vec<_>>();
427
428 let chat_template = if let Some(ref p) = $this.chat_template {
429 if p.ends_with(".json") || p.ends_with(".jinja") {
430 info!("Using chat template file at `{p}`");
431 Some(PathBuf::from_str(p)?)
432 } else {
433 panic!("Specified chat template file must end with .json or .jinja");
434 }
435 } else {
436 if $this.model_id.is_none() {
437 None
438 } else if dir_list.contains(&"chat_template.jinja".to_string()) {
439 info!("Loading `chat_template.jinja` at `{}`", this_model_id);
440 Some($crate::api_get_file!(
441 api,
442 "chat_template.jinja",
443 model_id
444 ))
445 } else {
446 info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id);
447 let res = $crate::api_get_file!(
448 api,
449 "tokenizer_config.json",
450 model_id
451 );
452 Some(res)
453 }
454 };
455
456 let filenames = get_model_paths(
457 revision.clone(),
458 &$token_source,
459 Some(&$quantized_model_id),
460 Some(&$quantized_filenames),
461 &api,
462 &model_id,
463 false, )?;
465
466 info!("GGUF file(s) {:?}", filenames);
467 let adapter_paths = get_xlora_paths(
468 this_model_id.clone(),
469 $this.xlora_model_id.as_ref(),
470 $this.lora_adapter_ids.as_ref(),
471 &$token_source,
472 revision.clone(),
473 $this.xlora_order.as_ref(),
474 )?;
475
476 let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) {
477 info!("Loading `generation_config.json` at `{}`", this_model_id);
478 Some($crate::api_get_file!(
479 api,
480 "generation_config.json",
481 model_id
482 ))
483 } else {
484 None
485 };
486
487 let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string())
488 {
489 info!("Loading `preprocessor_config.json` at `{}`", this_model_id);
490 Some($crate::api_get_file!(
491 api,
492 "preprocessor_config.json",
493 model_id
494 ))
495 } else {
496 None
497 };
498
499 let processor_config = if dir_list.contains(&"processor_config.json".to_string()) {
500 info!("Loading `processor_config.json` at `{}`", this_model_id);
501 Some($crate::api_get_file!(
502 api,
503 "processor_config.json",
504 model_id
505 ))
506 } else {
507 None
508 };
509
510 let tokenizer_filename = if $this.model_id.is_some() && dir_list.contains(&"tokenizer.json".to_string()) {
511 info!("Loading `tokenizer.json` at `{}`", this_model_id);
512 $crate::api_get_file!(api, "tokenizer.json", model_id)
513 } else {
514 PathBuf::from_str("")?
515 };
516
517 let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) {
518 info!("Loading `chat_template.json` at `{}`", this_model_id);
519 Some($crate::api_get_file!(
520 api,
521 "chat_template.json",
522 model_id
523 ))
524 } else {
525 None
526 };
527
528 Ok(Box::new($path_name {
529 tokenizer_filename,
530 config_filename: PathBuf::from_str("")?,
531 filenames,
532 adapter_paths,
533 template_filename: chat_template,
534 gen_conf,
535 preprocessor_config,
536 processor_config,
537 chat_template_json_filename,
538 }))
539 }};
540}
541
542#[doc(hidden)]
543#[macro_export]
544macro_rules! normal_model_loader {
545 (
546 $paths:expr,
547 $dtype:expr,
548 $device:expr,
549 $layer_devices:expr,
550 $config:expr,
551 $loader:expr,
552 $silent:expr,
553 $mapper:expr,
554 $loading_isq:expr,
555 $loading_uqff:expr,
556 $real_device:expr,
557 $attention_mechanism:expr,
558 $is_moqe:expr,
559 $multi_progress:expr,
560 $matformer_config:expr,
561 ) => {{
562 let regexes = if $loading_isq && $loading_uqff {
563 Some(std::sync::Arc::new(if $is_moqe {
565 $loader.isq_layer_regexes_moqe(&$config)?
566 } else {
567 $loader.isq_layer_regexes(&$config)?
568 }))
569 } else {
570 None
571 };
572 let get_device_for_tensor =
573 $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;
574
575 let vb = from_mmaped_safetensors(
576 $paths.get_weight_filenames().to_vec(),
577 Vec::new(),
578 $dtype,
579 $device,
580 $layer_devices,
581 $silent,
582 regexes,
583 |_| true, get_device_for_tensor,
585 )?;
586
587 $loader.load(
588 &$config,
589 vb,
590 $crate::pipeline::NormalLoadingMetadata {
591 mapper: $mapper,
592 loading_isq: $loading_isq,
593 real_device: $real_device,
594 multi_progress: $multi_progress,
595 matformer_slicing_config: $matformer_config,
596 },
597 $attention_mechanism,
598 )?
599 }};
600}
601
602#[doc(hidden)]
603#[macro_export]
604macro_rules! normal_model_loader_sharded {
605 (
606 $vb:expr,
607 $config:expr,
608 $loader:expr,
609 $mapper:expr,
610 $loading_isq:expr,
611 $real_device:expr,
612 $attention_mechanism:expr,
613 $multi_progress:expr,
614 $matformer_config:expr,
615 ) => {{
616 $loader.load(
617 &$config,
618 $vb,
619 $crate::pipeline::NormalLoadingMetadata {
620 mapper: $mapper,
621 loading_isq: $loading_isq,
622 real_device: $real_device,
623 multi_progress: $multi_progress,
624 matformer_slicing_config: $matformer_config,
625 },
626 $attention_mechanism,
627 )?
628 }};
629}
630
631#[doc(hidden)]
632#[macro_export]
633macro_rules! vision_normal_model_loader {
634 (
635 $paths:expr,
636 $dtype:expr,
637 $device:expr,
638 $layer_devices:expr,
639 $config:expr,
640 $loader:expr,
641 $silent:expr,
642 $mapper:expr,
643 $loading_isq:expr,
644 $loading_uqff:expr,
645 $real_device:expr,
646 $attention_mechanism:expr,
647 $multi_progress:expr,
648 $matformer_config:expr,
649 ) => {{
650 let regexes = if $loading_isq && $loading_uqff {
651 Some(std::sync::Arc::new($loader.isq_layer_regexes(&$config)?))
653 } else {
654 None
655 };
656 let get_device_for_tensor =
657 $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;
658
659 let vb = from_mmaped_safetensors(
660 $paths.get_weight_filenames().to_vec(),
661 Vec::new(),
662 $dtype,
663 $device,
664 $layer_devices,
665 $silent,
666 regexes,
667 |_| true, get_device_for_tensor,
669 )?;
670
671 $loader.load(
672 &$config,
673 vb,
674 $crate::pipeline::NormalLoadingMetadata {
675 mapper: $mapper,
676 loading_isq: $loading_isq,
677 real_device: $real_device,
678 multi_progress: $multi_progress,
679 matformer_slicing_config: $matformer_config,
680 },
681 $attention_mechanism,
682 )?
683 }};
684}
685
686#[doc(hidden)]
687#[macro_export]
688macro_rules! vision_normal_model_loader_sharded {
689 (
690 $vb:expr,
691 $config:expr,
692 $loader:expr,
693 $mapper:expr,
694 $loading_isq:expr,
695 $real_device:expr,
696 $attention_mechanism:expr,
697 $multi_progress:expr,
698 $matformer_config:expr,
699 ) => {{
700 $loader.load(
701 &$config,
702 $vb,
703 $crate::pipeline::NormalLoadingMetadata {
704 mapper: $mapper,
705 loading_isq: $loading_isq,
706 real_device: $real_device,
707 multi_progress: $multi_progress,
708 matformer_slicing_config: $matformer_config,
709 },
710 $attention_mechanism,
711 )?
712 }};
713}
714
715#[doc(hidden)]
716#[macro_export]
717macro_rules! embedding_normal_model_loader {
718 (
719 $paths:expr,
720 $dtype:expr,
721 $device:expr,
722 $layer_devices:expr,
723 $config:expr,
724 $loader:expr,
725 $silent:expr,
726 $mapper:expr,
727 $loading_isq:expr,
728 $loading_uqff:expr,
729 $real_device:expr,
730 $attention_mechanism:expr,
731 $multi_progress:expr,
732 ) => {{
733 let regexes = if $loading_isq && $loading_uqff {
734 Some(std::sync::Arc::new($loader.isq_layer_regexes(&$config)?))
736 } else {
737 None
738 };
739 let get_device_for_tensor =
740 $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;
741
742 let vb = from_mmaped_safetensors(
743 $paths.get_weight_filenames().to_vec(),
744 Vec::new(),
745 $dtype,
746 $device,
747 $layer_devices,
748 $silent,
749 regexes,
750 |_| true, get_device_for_tensor,
752 )?;
753
754 $loader.load(
755 &$config,
756 vb,
757 $crate::pipeline::NormalLoadingMetadata {
758 mapper: $mapper,
759 loading_isq: $loading_isq,
760 real_device: $real_device,
761 multi_progress: $multi_progress,
762 matformer_slicing_config: None,
763 },
764 $attention_mechanism,
765 )?
766 }};
767}
768
769#[doc(hidden)]
770#[macro_export]
771macro_rules! embedding_normal_model_loader_sharded {
772 (
773 $vb:expr,
774 $config:expr,
775 $loader:expr,
776 $mapper:expr,
777 $loading_isq:expr,
778 $real_device:expr,
779 $attention_mechanism:expr,
780 $multi_progress:expr,
781 ) => {{
782 $loader.load(
783 &$config,
784 $vb,
785 $crate::pipeline::NormalLoadingMetadata {
786 mapper: $mapper,
787 loading_isq: $loading_isq,
788 real_device: $real_device,
789 multi_progress: $multi_progress,
790 matformer_slicing_config: None,
791 },
792 $attention_mechanism,
793 )?
794 }};
795}
796
797#[doc(hidden)]
798#[macro_export]
799macro_rules! xlora_model_loader {
800 (
801 $paths:expr,
802 $dtype:expr,
803 $device:expr,
804 $layer_devices:expr,
805 $config:expr,
806 $loader:expr,
807 $silent:expr,
808 $mapper:expr,
809 $loading_isq:expr,
810 $real_device:expr,
811 $multi_progress:expr,
812 $matformer_config:expr,
813 ) => {{
814 let $crate::pipeline::AdapterPaths::XLora {
816 adapter_configs,
817 adapter_safetensors,
818 classifier_path,
819 xlora_order,
820 xlora_config,
821 lora_preload_adapter_info: _,
822 } = $paths.get_adapter_paths()
823 else {
824 unreachable!()
825 };
826
827 let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
828 safetensors_paths.push(classifier_path.as_ref().unwrap());
829 let get_device_for_tensor =
830 $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;
831
832 let vb = from_mmaped_safetensors(
833 safetensors_paths
834 .iter()
835 .map(|x| (*x).to_owned())
836 .collect::<Vec<_>>(),
837 adapter_safetensors
838 .as_ref()
839 .unwrap()
840 .iter()
841 .map(|(_, x)| (*x).to_owned())
842 .collect::<Vec<_>>(),
843 $dtype,
844 $device,
845 $layer_devices,
846 $silent,
847 None,
848 |_| true,
849 get_device_for_tensor,
850 )?;
851
852 $loader.load_xlora(
853 &$config,
854 vb,
855 adapter_configs.as_ref().unwrap(),
856 Some(xlora_config.as_ref().unwrap().clone()),
857 xlora_order.as_ref().unwrap().clone(),
858 $crate::pipeline::NormalLoadingMetadata {
859 mapper: $mapper,
860 loading_isq: $loading_isq,
861 real_device: $real_device,
862 multi_progress: $multi_progress,
863 matformer_slicing_config: $matformer_config,
864 },
865 &None,
866 )?
867 }};
868}
869
870#[doc(hidden)]
871#[macro_export]
872macro_rules! lora_model_loader {
873 (
874 $paths:expr,
875 $dtype:expr,
876 $device:expr,
877 $layer_devices:expr,
878 $config:expr,
879 $loader:expr,
880 $silent:expr,
881 $mapper:expr,
882 $loading_isq:expr,
883 $loading_uqff:expr,
884 $real_device:expr,
885 $attention_mechanism:expr,
886 $is_moqe:expr,
887 $multi_progress:expr,
888 $matformer_config:expr,
889 ) => {{
890 let $crate::pipeline::AdapterPaths::Lora(lora_adapter_paths) = $paths.get_adapter_paths()
891 else {
892 unreachable!()
893 };
894
895 let regexes = if $loading_isq && $loading_uqff {
896 Some(std::sync::Arc::new(if $is_moqe {
898 $loader.isq_layer_regexes_moqe(&$config)?
899 } else {
900 $loader.isq_layer_regexes(&$config)?
901 }))
902 } else {
903 None
904 };
905 let get_device_for_tensor =
906 $loader.get_device_for_tensor(&$config, &*$mapper, $loading_isq)?;
907
908 let vb = from_mmaped_safetensors(
909 $paths.get_weight_filenames().to_vec(),
910 Vec::new(),
911 $dtype,
912 $device,
913 $layer_devices,
914 $silent,
915 regexes,
916 |_| true, get_device_for_tensor.clone(),
918 )?;
919
920 for $crate::pipeline::LoraAdapterPaths {
921 adapter_path,
922 lora_config,
923 } in lora_adapter_paths
924 {
925 let lora_vb = from_mmaped_safetensors(
926 vec![adapter_path.clone()],
927 Vec::new(),
928 $dtype,
929 $device,
930 $layer_devices,
931 $silent,
932 None,
933 |_| true,
934 get_device_for_tensor.clone(),
935 )?;
936
937 mistralrs_quant::push_applied_lora(mistralrs_quant::LoraAdapter {
938 config: lora_config.clone(),
939 weights: lora_vb,
940 });
941 }
942
943 $loader.load(
944 &$config,
945 vb,
946 $crate::pipeline::NormalLoadingMetadata {
947 mapper: $mapper,
948 loading_isq: $loading_isq,
949 real_device: $real_device,
950 multi_progress: $multi_progress,
951 matformer_slicing_config: $matformer_config,
952 },
953 $attention_mechanism,
954 )?
955 }};
956}