1use std::{
2 collections::HashMap,
3 fs,
4 path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10 api::sync::{ApiBuilder, ApiRepo},
11 Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{info, warn};
16
17use crate::{
18 api_dir_list, api_get_file,
19 lora::LoraConfig,
20 pipeline::{
21 chat_template::{ChatTemplate, ChatTemplateValue},
22 isq::UQFF_RESIDUAL_SAFETENSORS,
23 },
24 utils::tokens::get_token,
25 xlora_models::XLoraConfig,
26 ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
// Sharded safetensors weights, e.g. `model-00001-of-00002.safetensors`.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
// Single-file (quantized) safetensors weights.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
// Sharded pickle checkpoints; the dot before the extension is escaped so it
// matches a literal `.` rather than any character.
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";
33
/// Resolved artifacts for one plain-LoRA adapter: its parsed config plus the
/// local path of its downloaded weights file.
#[derive(Clone, Debug)]
pub struct LoraAdapterPaths {
    /// Deserialized contents of the adapter's `adapter_config.json`.
    pub lora_config: mistralrs_quant::LoraConfig,
    /// Local path to the adapter's `adapter_model.safetensors`.
    pub adapter_path: PathBuf,
}
39
/// Resolved filesystem locations for whichever adapter scheme is in use.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
pub enum AdapterPaths {
    /// X-LoRA: classifier and per-adapter artifacts plus the ordering info.
    XLora {
        /// `((1-based ordering index, adapter name), config)` entries.
        adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
        /// `(adapter name, local safetensors path)` pairs.
        adapter_safetensors: Option<Vec<(String, PathBuf)>>,
        /// Local path to `xlora_classifier.safetensors`, when one was found.
        classifier_path: Option<PathBuf>,
        /// The adapter ordering supplied by the caller.
        xlora_order: Option<Ordering>,
        /// Parsed `xlora_config.json`, when one deserialized successfully.
        xlora_config: Option<XLoraConfig>,
        /// Adapters to preload: name -> (weights path, config).
        lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
    },
    /// Plain LoRA adapters, one entry per requested adapter ID.
    Lora(Vec<LoraAdapterPaths>),
    /// No adapters requested.
    None,
}
54
55pub fn get_xlora_paths(
56 base_model_id: String,
57 xlora_model_id: Option<&String>,
58 lora_adapter_ids: Option<&Vec<String>>,
59 token_source: &TokenSource,
60 revision: String,
61 xlora_order: Option<&Ordering>,
62) -> Result<AdapterPaths> {
63 match (lora_adapter_ids, xlora_model_id, xlora_order) {
64 (None, Some(xlora_id), Some(xlora_order)) => {
65 let api = {
66 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
67 let mut api = ApiBuilder::from_cache(cache)
68 .with_progress(true)
69 .with_token(get_token(token_source)?);
70 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
71 api = api.with_cache_dir(x.into());
72 }
73 api.build().map_err(candle_core::Error::msg)?
74 };
75 let api = api.repo(Repo::with_revision(
76 xlora_id.clone(),
77 RepoType::Model,
78 revision,
79 ));
80 let model_id = Path::new(&xlora_id);
81
82 let xlora_classifier = &api_dir_list!(api, model_id)
84 .filter(|x| x.contains("xlora_classifier.safetensors"))
85 .collect::<Vec<_>>();
86 if xlora_classifier.len() > 1 {
87 warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
88 warn!("Selected classifier: `{}`", &xlora_classifier[0]);
89 }
90 let xlora_classifier = xlora_classifier.first();
91
92 let classifier_path = xlora_classifier
93 .map(|xlora_classifier| api_get_file!(api, xlora_classifier, model_id));
94
95 let xlora_configs = &api_dir_list!(api, model_id)
98 .filter(|x| x.contains("xlora_config.json"))
99 .collect::<Vec<_>>();
100 if xlora_configs.len() > 1 {
101 warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
102 }
103
104 let mut xlora_config: Option<XLoraConfig> = None;
105 let mut last_err: Option<serde_json::Error> = None;
106 for (i, config_path) in xlora_configs.iter().enumerate() {
107 if xlora_configs.len() != 1 {
108 warn!("Selecting config: `{}`", config_path);
109 }
110 let config_path = api_get_file!(api, config_path, model_id);
111 let conf = fs::read_to_string(config_path)?;
112 let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
113 match deser {
114 Ok(conf) => {
115 xlora_config = Some(conf);
116 break;
117 }
118 Err(e) => {
119 if i != xlora_configs.len() - 1 {
120 warn!("Config is broken with error `{e}`");
121 }
122 last_err = Some(e);
123 }
124 }
125 }
126 let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
127 if let Some(last_err) = last_err {
128 panic!(
129 "Unable to derserialize any configs. Last error: {}",
130 last_err
131 )
132 } else {
133 None
134 }
135 });
136
137 let adapter_files = api_dir_list!(api, model_id)
139 .filter_map(|name| {
140 if let Some(ref adapters) = xlora_order.adapters {
141 for adapter_name in adapters {
142 if name.contains(adapter_name) {
143 return Some((name, adapter_name.clone()));
144 }
145 }
146 }
147 None
148 })
149 .collect::<Vec<_>>();
150 if adapter_files.is_empty() && xlora_order.adapters.is_some() {
151 anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
152 }
153
154 let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
156 for (file, name) in adapter_files {
157 if let Some(paths) = adapters_paths.get_mut(&name) {
158 paths.push(api_get_file!(api, &file, model_id));
159 } else {
160 adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
161 }
162 }
163
164 let mut adapters_configs = Vec::new();
166 let mut adapters_safetensors = Vec::new();
167 if let Some(ref adapters) = xlora_order.adapters {
168 for (i, name) in adapters.iter().enumerate() {
169 let paths = adapters_paths
170 .get(name)
171 .unwrap_or_else(|| panic!("Adapter {name} not found."));
172 for path in paths {
173 if path.extension().unwrap() == "safetensors" {
174 adapters_safetensors.push((name.clone(), path.to_owned()));
175 } else {
176 let conf = fs::read_to_string(path)?;
177 let lora_config: LoraConfig = serde_json::from_str(&conf)?;
178 adapters_configs
179 .push((((i + 1).to_string(), name.clone()), lora_config));
180 }
181 }
182 }
183 }
184
185 if xlora_order.base_model_id
187 != *xlora_config
188 .as_ref()
189 .map(|cfg| &cfg.base_model_id)
190 .unwrap_or(&base_model_id)
191 || xlora_config
192 .as_ref()
193 .map(|cfg| &cfg.base_model_id)
194 .unwrap_or(&base_model_id)
195 != &base_model_id
196 {
197 anyhow::bail!(
198 "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
199 xlora_order.base_model_id,
200 xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
201 base_model_id
202 );
203 }
204
205 let lora_preload_adapter_info =
206 if let Some(preload_adapters) = &xlora_order.preload_adapters {
208 let mut output = HashMap::new();
209 for adapter in preload_adapters {
210 let adapter_files = api_dir_list!(api, &adapter.adapter_model_id)
212 .filter_map(|f| {
213 if f.contains(&adapter.name) {
214 Some((f, adapter.name.clone()))
215 } else {
216 None
217 }
218 })
219 .collect::<Vec<_>>();
220 if adapter_files.is_empty() {
221 anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
222 }
223 let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
225 for (file, name) in adapter_files {
226 if let Some(paths) = adapters_paths.get_mut(&name) {
227 paths.push(api_get_file!(api, &file, model_id));
228 } else {
229 adapters_paths
230 .insert(name, vec![api_get_file!(api, &file, model_id)]);
231 }
232 }
233
234 let mut config = None;
235 let mut safetensor = None;
236
237 let paths = adapters_paths
239 .get(&adapter.name)
240 .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
241 for path in paths {
242 if path.extension().unwrap() == "safetensors" {
243 safetensor = Some(path.to_owned());
244 } else {
245 let conf = fs::read_to_string(path)?;
246 let lora_config: LoraConfig = serde_json::from_str(&conf)?;
247 config = Some(lora_config);
248 }
249 }
250
251 let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
252 output.insert(adapter.name.clone(), (safetensor, config));
253 }
254 Some(output)
255 } else {
256 None
257 };
258
259 Ok(AdapterPaths::XLora {
260 adapter_configs: Some(adapters_configs),
261 adapter_safetensors: Some(adapters_safetensors),
262 classifier_path,
263 xlora_order: Some(xlora_order.clone()),
264 xlora_config,
265 lora_preload_adapter_info,
266 })
267 }
268 (Some(adapter_ids), None, None) => {
269 let mut lora_adapter_paths = Vec::new();
270 for adapter_id in adapter_ids {
271 info!("Loading adapter at `{adapter_id}`");
272
273 let api = {
274 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
275 let mut api = ApiBuilder::from_cache(cache)
276 .with_progress(true)
277 .with_token(get_token(token_source)?);
278 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
279 api = api.with_cache_dir(x.into());
280 }
281 api.build().map_err(candle_core::Error::msg)?
282 };
283 let api = api.repo(Repo::with_revision(
284 adapter_id.clone(),
285 RepoType::Model,
286 revision.clone(),
287 ));
288
289 let config_path = api.get("adapter_config.json")?;
290 let adapter_path = api.get("adapter_model.safetensors")?;
291 let lora_config: mistralrs_quant::LoraConfig =
292 serde_json::from_str(&fs::read_to_string(config_path)?)?;
293
294 lora_adapter_paths.push(LoraAdapterPaths {
295 lora_config,
296 adapter_path,
297 });
298 }
299
300 Ok(AdapterPaths::Lora(lora_adapter_paths))
301 }
302 (None, None, None) => Ok(AdapterPaths::None),
303 _ => anyhow::bail!(
304 "Incorrect configuration for an adapter model. Lora and XLora are mutually exclusive."
305 ),
306 }
307}
308
309pub fn get_model_paths(
310 revision: String,
311 token_source: &TokenSource,
312 quantized_model_id: Option<&String>,
313 quantized_filename: Option<&Vec<String>>,
314 api: &ApiRepo,
315 model_id: &Path,
316 loading_from_uqff: bool,
317) -> Result<Vec<PathBuf>> {
318 match quantized_filename {
319 Some(names) => {
320 let id = quantized_model_id.unwrap();
321 let mut files = Vec::new();
322
323 for name in names {
324 let qapi = {
325 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
326 let mut api = ApiBuilder::from_cache(cache)
327 .with_progress(true)
328 .with_token(get_token(token_source)?);
329 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
330 api = api.with_cache_dir(x.into());
331 }
332 api.build().map_err(candle_core::Error::msg)?
333 };
334 let qapi = qapi.repo(Repo::with_revision(
335 id.to_string(),
336 RepoType::Model,
337 revision.clone(),
338 ));
339 let model_id = Path::new(&id);
340 files.push(api_get_file!(qapi, name, model_id));
341 }
342 Ok(files)
343 }
344 None => {
345 let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
347 let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
348 let pickle_match = Regex::new(PICKLE_MATCH)?;
349
350 let mut filenames = vec![];
351 let listing = api_dir_list!(api, model_id).filter(|x| {
352 safetensor_match.is_match(x)
353 || pickle_match.is_match(x)
354 || quant_safetensor_match.is_match(x)
355 || x == UQFF_RESIDUAL_SAFETENSORS
356 });
357 let safetensors = listing
358 .clone()
359 .filter(|x| x.ends_with(".safetensors"))
360 .collect::<Vec<_>>();
361 let pickles = listing
362 .clone()
363 .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
364 .collect::<Vec<_>>();
365 let uqff_residual = listing
366 .clone()
367 .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
368 .collect::<Vec<_>>();
369 let files = if !safetensors.is_empty() {
370 safetensors
372 } else if !pickles.is_empty() {
373 pickles
375 } else if !uqff_residual.is_empty() && loading_from_uqff {
376 uqff_residual
377 } else {
378 anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
379 };
380 info!(
381 "Found model weight filenames {:?}",
382 files
383 .iter()
384 .map(|x| x.split('/').next_back().unwrap())
385 .collect::<Vec<_>>()
386 );
387 for rfilename in files {
388 filenames.push(api_get_file!(api, &rfilename, model_id));
389 }
390 Ok(filenames)
391 }
392 }
393}
394
395#[allow(clippy::borrowed_box)]
408pub(crate) fn get_chat_template(
409 paths: &Box<dyn ModelPaths>,
410 jinja_explicit: Option<&String>,
411 chat_template_explicit: Option<&String>,
412 chat_template_fallback: Option<&String>,
413 chat_template_ovrd: Option<String>,
414) -> ChatTemplate {
415 let template_content = if let Some(template_filename) = paths.get_template_filename() {
417 if !["jinja", "json"].contains(
418 &template_filename
419 .extension()
420 .expect("Template filename must be a file")
421 .to_string_lossy()
422 .to_string()
423 .as_str(),
424 ) {
425 panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
426 }
427 Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
428 } else if chat_template_fallback.is_some_and(|f| f.ends_with(".json")) {
429 let template_filename = chat_template_fallback
431 .expect("A tokenizer config or chat template file path must be specified.");
432 Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
433 } else if chat_template_ovrd.is_some() {
434 None
435 } else {
436 panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there. If you are running a GGUF model, it probably does not contain a chat template.");
437 };
438 let mut template: ChatTemplate = match chat_template_ovrd {
439 Some(chat_template) => {
440 info!("Using literal chat template.");
442 let mut template = ChatTemplate::default();
443 template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
444 template
445 }
446 None => serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap(),
447 };
448 if template.chat_template.is_none() {
450 if let Some(chat_template_explicit) = chat_template_explicit {
451 let ct =
452 fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
453
454 let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
455 ct
456 } else {
457 #[derive(Debug, serde::Deserialize)]
458 struct AutomaticTemplate {
459 chat_template: String,
460 }
461 let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
462 deser.chat_template
463 };
464
465 template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
466 }
467 }
468
469 if let Some(jinja_explicit) = jinja_explicit {
471 if !jinja_explicit.ends_with(".jinja") {
472 panic!("jinja_explicit must end with .jinja!");
473 }
474
475 let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
476
477 template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
478 }
479
480 let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
481 .get_processor_config()
482 .as_ref()
483 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
484 if let Some(processor_conf) = processor_conf {
485 if processor_conf.chat_template.is_some() {
486 template.chat_template = processor_conf
487 .chat_template
488 .map(|x| ChatTemplateValue(Either::Left(x)));
489 }
490 }
491
492 #[derive(Debug, serde::Deserialize)]
493 struct SpecifiedTemplate {
494 chat_template: String,
495 bos_token: Option<String>,
496 eos_token: Option<String>,
497 unk_token: Option<String>,
498 }
499
500 if template.chat_template.is_some() {
501 return template;
502 };
503
504 match &template.chat_template {
505 Some(_) => template,
506 None => {
507 info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
508 let mut deser: HashMap<String, Value> =
509 serde_json::from_str(&template_content.unwrap()).unwrap();
510
511 match chat_template_fallback.cloned() {
512 Some(t) => {
513 info!("Loading specified loading chat template file at `{t}`.");
514 let templ: SpecifiedTemplate =
515 serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
516 deser.insert(
517 "chat_template".to_string(),
518 Value::String(templ.chat_template),
519 );
520 if templ.bos_token.is_some() {
521 deser.insert(
522 "bos_token".to_string(),
523 Value::String(templ.bos_token.unwrap()),
524 );
525 }
526 if templ.eos_token.is_some() {
527 deser.insert(
528 "eos_token".to_string(),
529 Value::String(templ.eos_token.unwrap()),
530 );
531 }
532 if templ.unk_token.is_some() {
533 deser.insert(
534 "unk_token".to_string(),
535 Value::String(templ.unk_token.unwrap()),
536 );
537 }
538 }
539 None => {
540 info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
541 deser.insert("chat_template".to_string(), Value::Null);
542 }
543 }
544
545 let ser = serde_json::to_string_pretty(&deser)
546 .expect("Serialization of modified chat template failed.");
547 serde_json::from_str(&ser).unwrap()
548 }
549 }
550}
551
// `#[cfg(test)]` keeps the test module (and its regex compilation) out of
// non-test builds; it was previously compiled unconditionally.
#[cfg(test)]
mod tests {
    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        use regex_automata::meta::Regex;

        use super::SAFETENSOR_MATCH;
        let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;

        let positive_ids = [
            "model-00001-of-00001.safetensors",
            "model-00002-of-00002.safetensors",
            "model-00003-of-00003.safetensors",
            "model-00004-of-00004.safetensors",
            "model-00005-of-00005.safetensors",
            "model-00006-of-00006.safetensors",
        ];
        let negative_ids = [
            "model-0000a-of-00002.safetensors",
            "consolidated.safetensors",
        ];
        for id in positive_ids {
            assert!(safetensor_match.is_match(id));
        }
        for id in negative_ids {
            assert!(!safetensor_match.is_match(id));
        }
        Ok(())
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        use regex_automata::meta::Regex;

        use super::PICKLE_MATCH;
        let pickle_match = Regex::new(PICKLE_MATCH)?;

        let positive_ids = [
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ];
        let negative_ids = [
            "pytorch_model-000001-of-00001.bin",
            "pytorch_model-0000a-of-00002.bin",
            "pytorch_model-000-of-00003.bin",
            "pytorch_consolidated.bin",
        ];
        for id in positive_ids {
            assert!(pickle_match.is_match(id));
        }
        for id in negative_ids {
            assert!(!pickle_match.is_match(id));
        }
        Ok(())
    }
}