1use std::{
2 collections::HashMap,
3 fs,
4 path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10 api::sync::{ApiBuilder, ApiRepo},
11 Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{info, warn};
16
17use crate::{
18 api_dir_list, api_get_file,
19 lora::LoraConfig,
20 pipeline::{
21 chat_template::{ChatTemplate, ChatTemplateValue},
22 isq::UQFF_RESIDUAL_SAFETENSORS,
23 },
24 utils::tokens::get_token,
25 xlora_models::XLoraConfig,
26 ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
/// Sharded safetensors weight files, e.g. `model-00001-of-00002.safetensors`.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
/// Single-file safetensors weights (used for pre-quantized models), e.g. `model.safetensors`.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
/// Sharded PyTorch pickle weights, e.g. `pytorch_model-00001-of-00002.bin`.
/// The dot before the extension is escaped so it only matches a literal `.`.
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";
33
34#[derive(Clone, Debug)]
35pub struct LoraAdapterPaths {
36 pub lora_config: mistralrs_quant::LoraConfig,
37 pub adapter_path: PathBuf,
38}
39
40#[allow(clippy::large_enum_variant)]
41#[derive(Clone, Debug)]
42pub enum AdapterPaths {
43 XLora {
44 adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
45 adapter_safetensors: Option<Vec<(String, PathBuf)>>,
46 classifier_path: Option<PathBuf>,
47 xlora_order: Option<Ordering>,
48 xlora_config: Option<XLoraConfig>,
49 lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
50 },
51 Lora(Vec<LoraAdapterPaths>),
52 None,
53}
54
55pub fn get_xlora_paths(
56 base_model_id: String,
57 xlora_model_id: Option<&String>,
58 lora_adapter_ids: Option<&Vec<String>>,
59 token_source: &TokenSource,
60 revision: String,
61 xlora_order: Option<&Ordering>,
62) -> Result<AdapterPaths> {
63 match (lora_adapter_ids, xlora_model_id, xlora_order) {
64 (None, Some(xlora_id), Some(xlora_order)) => {
65 let api = {
66 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
67 let mut api = ApiBuilder::from_cache(cache)
68 .with_progress(true)
69 .with_token(get_token(token_source)?);
70 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
71 api = api.with_cache_dir(x.into());
72 }
73 api.build().map_err(candle_core::Error::msg)?
74 };
75 let api = api.repo(Repo::with_revision(
76 xlora_id.clone(),
77 RepoType::Model,
78 revision,
79 ));
80 let model_id = Path::new(&xlora_id);
81 let dir_list = api_dir_list!(api, model_id, true).collect::<Vec<_>>();
82 let xlora_classifier = &dir_list
84 .clone()
85 .into_iter()
86 .filter(|x| x.contains("xlora_classifier.safetensors"))
87 .collect::<Vec<_>>();
88 if xlora_classifier.len() > 1 {
89 warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
90 warn!("Selected classifier: `{}`", &xlora_classifier[0]);
91 }
92 let xlora_classifier = xlora_classifier.first();
93
94 let classifier_path = xlora_classifier
95 .map(|xlora_classifier| api_get_file!(api, xlora_classifier, model_id));
96
97 let xlora_configs = &dir_list
100 .clone()
101 .into_iter()
102 .filter(|x| x.contains("xlora_config.json"))
103 .collect::<Vec<_>>();
104 if xlora_configs.len() > 1 {
105 warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
106 }
107
108 let mut xlora_config: Option<XLoraConfig> = None;
109 let mut last_err: Option<serde_json::Error> = None;
110 for (i, config_path) in xlora_configs.iter().enumerate() {
111 if xlora_configs.len() != 1 {
112 warn!("Selecting config: `{}`", config_path);
113 }
114 let config_path = api_get_file!(api, config_path, model_id);
115 let conf = fs::read_to_string(config_path)?;
116 let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
117 match deser {
118 Ok(conf) => {
119 xlora_config = Some(conf);
120 break;
121 }
122 Err(e) => {
123 if i != xlora_configs.len() - 1 {
124 warn!("Config is broken with error `{e}`");
125 }
126 last_err = Some(e);
127 }
128 }
129 }
130 let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
131 if let Some(last_err) = last_err {
132 panic!("Unable to derserialize any configs. Last error: {last_err}")
133 } else {
134 None
135 }
136 });
137
138 let adapter_files = dir_list
140 .into_iter()
141 .filter_map(|name| {
142 if let Some(ref adapters) = xlora_order.adapters {
143 for adapter_name in adapters {
144 if name.contains(adapter_name) {
145 return Some((name, adapter_name.clone()));
146 }
147 }
148 }
149 None
150 })
151 .collect::<Vec<_>>();
152 if adapter_files.is_empty() && xlora_order.adapters.is_some() {
153 anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
154 }
155
156 let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
158 for (file, name) in adapter_files {
159 if let Some(paths) = adapters_paths.get_mut(&name) {
160 paths.push(api_get_file!(api, &file, model_id));
161 } else {
162 adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
163 }
164 }
165
166 let mut adapters_configs = Vec::new();
168 let mut adapters_safetensors = Vec::new();
169 if let Some(ref adapters) = xlora_order.adapters {
170 for (i, name) in adapters.iter().enumerate() {
171 let paths = adapters_paths
172 .get(name)
173 .unwrap_or_else(|| panic!("Adapter {name} not found."));
174 for path in paths {
175 if path.extension().unwrap() == "safetensors" {
176 adapters_safetensors.push((name.clone(), path.to_owned()));
177 } else {
178 let conf = fs::read_to_string(path)?;
179 let lora_config: LoraConfig = serde_json::from_str(&conf)?;
180 adapters_configs
181 .push((((i + 1).to_string(), name.clone()), lora_config));
182 }
183 }
184 }
185 }
186
187 if xlora_order.base_model_id
189 != *xlora_config
190 .as_ref()
191 .map(|cfg| &cfg.base_model_id)
192 .unwrap_or(&base_model_id)
193 || xlora_config
194 .as_ref()
195 .map(|cfg| &cfg.base_model_id)
196 .unwrap_or(&base_model_id)
197 != &base_model_id
198 {
199 anyhow::bail!(
200 "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
201 xlora_order.base_model_id,
202 xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
203 base_model_id
204 );
205 }
206
207 let lora_preload_adapter_info =
208 if let Some(preload_adapters) = &xlora_order.preload_adapters {
210 let mut output = HashMap::new();
211 for adapter in preload_adapters {
212 let adapter_files = api_dir_list!(api, &adapter.adapter_model_id, true)
214 .filter_map(|f| {
215 if f.contains(&adapter.name) {
216 Some((f, adapter.name.clone()))
217 } else {
218 None
219 }
220 })
221 .collect::<Vec<_>>();
222 if adapter_files.is_empty() {
223 anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
224 }
225 let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
227 for (file, name) in adapter_files {
228 if let Some(paths) = adapters_paths.get_mut(&name) {
229 paths.push(api_get_file!(api, &file, model_id));
230 } else {
231 adapters_paths
232 .insert(name, vec![api_get_file!(api, &file, model_id)]);
233 }
234 }
235
236 let mut config = None;
237 let mut safetensor = None;
238
239 let paths = adapters_paths
241 .get(&adapter.name)
242 .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
243 for path in paths {
244 if path.extension().unwrap() == "safetensors" {
245 safetensor = Some(path.to_owned());
246 } else {
247 let conf = fs::read_to_string(path)?;
248 let lora_config: LoraConfig = serde_json::from_str(&conf)?;
249 config = Some(lora_config);
250 }
251 }
252
253 let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
254 output.insert(adapter.name.clone(), (safetensor, config));
255 }
256 Some(output)
257 } else {
258 None
259 };
260
261 Ok(AdapterPaths::XLora {
262 adapter_configs: Some(adapters_configs),
263 adapter_safetensors: Some(adapters_safetensors),
264 classifier_path,
265 xlora_order: Some(xlora_order.clone()),
266 xlora_config,
267 lora_preload_adapter_info,
268 })
269 }
270 (Some(adapter_ids), None, None) => {
271 let mut lora_adapter_paths = Vec::new();
272 for adapter_id in adapter_ids {
273 info!("Loading adapter at `{adapter_id}`");
274
275 let api = {
276 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
277 let mut api = ApiBuilder::from_cache(cache)
278 .with_progress(true)
279 .with_token(get_token(token_source)?);
280 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
281 api = api.with_cache_dir(x.into());
282 }
283 api.build().map_err(candle_core::Error::msg)?
284 };
285 let api = api.repo(Repo::with_revision(
286 adapter_id.clone(),
287 RepoType::Model,
288 revision.clone(),
289 ));
290
291 let config_path = api.get("adapter_config.json")?;
292 let adapter_path = api.get("adapter_model.safetensors")?;
293 let lora_config: mistralrs_quant::LoraConfig =
294 serde_json::from_str(&fs::read_to_string(config_path)?)?;
295
296 lora_adapter_paths.push(LoraAdapterPaths {
297 lora_config,
298 adapter_path,
299 });
300 }
301
302 Ok(AdapterPaths::Lora(lora_adapter_paths))
303 }
304 (None, None, None) => Ok(AdapterPaths::None),
305 _ => anyhow::bail!(
306 "Incorrect configuration for an adapter model. Lora and XLora are mutually exclusive."
307 ),
308 }
309}
310
311pub fn get_model_paths(
312 revision: String,
313 token_source: &TokenSource,
314 quantized_model_id: Option<&String>,
315 quantized_filename: Option<&Vec<String>>,
316 api: &ApiRepo,
317 model_id: &Path,
318 loading_from_uqff: bool,
319) -> Result<Vec<PathBuf>> {
320 match quantized_filename {
321 Some(names) => {
322 let id = quantized_model_id.unwrap();
323 let mut files = Vec::new();
324
325 for name in names {
326 let qapi = {
327 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
328 let mut api = ApiBuilder::from_cache(cache)
329 .with_progress(true)
330 .with_token(get_token(token_source)?);
331 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
332 api = api.with_cache_dir(x.into());
333 }
334 api.build().map_err(candle_core::Error::msg)?
335 };
336 let qapi = qapi.repo(Repo::with_revision(
337 id.to_string(),
338 RepoType::Model,
339 revision.clone(),
340 ));
341 let model_id = Path::new(&id);
342 files.push(api_get_file!(qapi, name, model_id));
343 }
344 Ok(files)
345 }
346 None => {
347 let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
349 let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
350 let pickle_match = Regex::new(PICKLE_MATCH)?;
351
352 let mut filenames = vec![];
353 let listing = api_dir_list!(api, model_id, true).filter(|x| {
354 safetensor_match.is_match(x)
355 || pickle_match.is_match(x)
356 || quant_safetensor_match.is_match(x)
357 || x == UQFF_RESIDUAL_SAFETENSORS
358 });
359 let safetensors = listing
360 .clone()
361 .filter(|x| x.ends_with(".safetensors"))
362 .collect::<Vec<_>>();
363 let pickles = listing
364 .clone()
365 .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
366 .collect::<Vec<_>>();
367 let uqff_residual = listing
368 .clone()
369 .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
370 .collect::<Vec<_>>();
371 let files = if !safetensors.is_empty() {
372 safetensors
374 } else if !pickles.is_empty() {
375 pickles
377 } else if !uqff_residual.is_empty() && loading_from_uqff {
378 uqff_residual
379 } else {
380 anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
381 };
382 info!(
383 "Found model weight filenames {:?}",
384 files
385 .iter()
386 .map(|x| x.split('/').next_back().unwrap())
387 .collect::<Vec<_>>()
388 );
389 for rfilename in files {
390 filenames.push(api_get_file!(api, &rfilename, model_id));
391 }
392 Ok(filenames)
393 }
394 }
395}
396
397#[allow(clippy::borrowed_box)]
410pub(crate) fn get_chat_template(
411 paths: &Box<dyn ModelPaths>,
412 jinja_explicit: Option<&String>,
413 chat_template_explicit: Option<&String>,
414 chat_template_fallback: Option<&String>,
415 chat_template_ovrd: Option<String>,
416) -> ChatTemplate {
417 let template_content = if let Some(template_filename) = paths.get_template_filename() {
419 if !["jinja", "json"].contains(
420 &template_filename
421 .extension()
422 .expect("Template filename must be a file")
423 .to_string_lossy()
424 .to_string()
425 .as_str(),
426 ) {
427 panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
428 }
429 Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
430 } else if chat_template_fallback.is_some_and(|f| f.ends_with(".json")) {
431 let template_filename = chat_template_fallback
433 .expect("A tokenizer config or chat template file path must be specified.");
434 Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
435 } else if chat_template_ovrd.is_some() {
436 None
437 } else {
438 panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there. If you are running a GGUF model, it probably does not contain a chat template.");
439 };
440 let mut template: ChatTemplate = match chat_template_ovrd {
441 Some(chat_template) => {
442 info!("Using literal chat template.");
444 let mut template = ChatTemplate::default();
445 template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
446 template
447 }
448 None => {
449 if let Some(template_filename) = paths.get_template_filename() {
451 if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
452 info!("Using chat template from .jinja file.");
453 let mut template = ChatTemplate::default();
454 template.chat_template = Some(ChatTemplateValue(Either::Left(
455 template_content.as_ref().unwrap().clone(),
456 )));
457 template
458 } else {
459 serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap()
460 }
461 } else {
462 serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap()
463 }
464 }
465 };
466 if template.chat_template.is_none() {
468 if let Some(chat_template_explicit) = chat_template_explicit {
469 let ct =
470 fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
471
472 let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
473 ct
474 } else {
475 #[derive(Debug, serde::Deserialize)]
476 struct AutomaticTemplate {
477 chat_template: String,
478 }
479 let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
480 deser.chat_template
481 };
482
483 template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
484 }
485 }
486
487 if let Some(jinja_explicit) = jinja_explicit {
489 if !jinja_explicit.ends_with(".jinja") {
490 panic!("jinja_explicit must end with .jinja!");
491 }
492
493 let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
494
495 template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
496 }
497
498 let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
499 .get_processor_config()
500 .as_ref()
501 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
502 if let Some(processor_conf) = processor_conf {
503 if processor_conf.chat_template.is_some() {
504 template.chat_template = processor_conf
505 .chat_template
506 .map(|x| ChatTemplateValue(Either::Left(x)));
507 }
508 }
509
510 #[derive(Debug, serde::Deserialize)]
511 struct SpecifiedTemplate {
512 chat_template: String,
513 bos_token: Option<String>,
514 eos_token: Option<String>,
515 unk_token: Option<String>,
516 }
517
518 if template.chat_template.is_some() {
519 return template;
520 };
521
522 match &template.chat_template {
523 Some(_) => template,
524 None => {
525 info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
526 let mut deser: HashMap<String, Value> =
527 serde_json::from_str(&template_content.unwrap()).unwrap();
528
529 match chat_template_fallback.cloned() {
530 Some(t) => {
531 info!("Loading specified loading chat template file at `{t}`.");
532 let templ: SpecifiedTemplate =
533 serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
534 deser.insert(
535 "chat_template".to_string(),
536 Value::String(templ.chat_template),
537 );
538 if templ.bos_token.is_some() {
539 deser.insert(
540 "bos_token".to_string(),
541 Value::String(templ.bos_token.unwrap()),
542 );
543 }
544 if templ.eos_token.is_some() {
545 deser.insert(
546 "eos_token".to_string(),
547 Value::String(templ.eos_token.unwrap()),
548 );
549 }
550 if templ.unk_token.is_some() {
551 deser.insert(
552 "unk_token".to_string(),
553 Value::String(templ.unk_token.unwrap()),
554 );
555 }
556 }
557 None => {
558 warn!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
559 deser.insert("chat_template".to_string(), Value::Null);
560 }
561 }
562
563 let ser = serde_json::to_string_pretty(&deser)
564 .expect("Serialization of modified chat template failed.");
565 serde_json::from_str(&ser).unwrap()
566 }
567 }
568}
569
/// Unit tests for the weight-filename regexes.
#[cfg(test)]
mod tests {
    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        use regex_automata::meta::Regex;

        use super::SAFETENSOR_MATCH;
        let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;

        let positive_ids = [
            "model-00001-of-00001.safetensors",
            "model-00002-of-00002.safetensors",
            "model-00003-of-00003.safetensors",
            "model-00004-of-00004.safetensors",
            "model-00005-of-00005.safetensors",
            "model-00006-of-00006.safetensors",
        ];
        let negative_ids = [
            "model-0000a-of-00002.safetensors",
            "consolidated.safetensors",
        ];
        for id in positive_ids {
            assert!(safetensor_match.is_match(id));
        }
        for id in negative_ids {
            assert!(!safetensor_match.is_match(id));
        }
        Ok(())
    }

    #[test]
    fn match_quant_safetensors() -> anyhow::Result<()> {
        use regex_automata::meta::Regex;

        use super::QUANT_SAFETENSOR_MATCH;
        let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;

        assert!(quant_safetensor_match.is_match("model.safetensors"));
        for id in [
            "consolidated.safetensors",
            "model-00001-of-00002.safetensors",
        ] {
            assert!(!quant_safetensor_match.is_match(id));
        }
        Ok(())
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        use regex_automata::meta::Regex;

        use super::PICKLE_MATCH;
        let pickle_match = Regex::new(PICKLE_MATCH)?;

        let positive_ids = [
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ];
        let negative_ids = [
            "pytorch_model-000001-of-00001.bin",
            "pytorch_model-0000a-of-00002.bin",
            "pytorch_model-000-of-00003.bin",
            "pytorch_consolidated.bin",
        ];
        for id in positive_ids {
            assert!(pickle_match.is_match(id));
        }
        for id in negative_ids {
            assert!(!pickle_match.is_match(id));
        }
        Ok(())
    }
}