1use std::{
2 collections::HashMap,
3 fs,
4 path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10 api::sync::{ApiBuilder, ApiRepo},
11 Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{info, warn};
16
17use crate::{
18 api_dir_list, api_get_file,
19 lora::LoraConfig,
20 pipeline::{
21 chat_template::{ChatTemplate, ChatTemplateValue},
22 isq::UQFF_RESIDUAL_SAFETENSORS,
23 },
24 utils::tokens::get_token,
25 xlora_models::XLoraConfig,
26 ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
/// Matches names containing a sharded safetensors weight file, e.g.
/// `model-00001-of-00002.safetensors`.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
/// Matches names containing a single-file `model.safetensors` checkpoint.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
/// Matches names containing a sharded PyTorch pickle weight file, e.g.
/// `pytorch_model-00001-of-00002.bin` (extensions `.pth`, `.pt` or `.bin`).
///
/// The dot before the extension is escaped so that only a literal `.` is
/// accepted there; an unescaped `.` would match any character.
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";
33
34pub(crate) struct XLoraPaths {
35 pub adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
36 pub adapter_safetensors: Option<Vec<(String, PathBuf)>>,
37 pub classifier_path: Option<PathBuf>,
38 pub xlora_order: Option<Ordering>,
39 pub xlora_config: Option<XLoraConfig>,
40 pub lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
41}
42
43pub fn get_xlora_paths(
44 base_model_id: String,
45 xlora_model_id: &Option<String>,
46 token_source: &TokenSource,
47 revision: String,
48 xlora_order: &Option<Ordering>,
49) -> Result<XLoraPaths> {
50 Ok(if let Some(ref xlora_id) = xlora_model_id {
51 let api = {
52 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
53 let mut api = ApiBuilder::from_cache(cache)
54 .with_progress(true)
55 .with_token(get_token(token_source)?);
56 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
57 api = api.with_cache_dir(x.into());
58 }
59 api.build().map_err(candle_core::Error::msg)?
60 };
61 let api = api.repo(Repo::with_revision(
62 xlora_id.clone(),
63 RepoType::Model,
64 revision,
65 ));
66 let model_id = Path::new(&xlora_id);
67
68 let xlora_classifier = &api_dir_list!(api, model_id)
70 .filter(|x| x.contains("xlora_classifier.safetensors"))
71 .collect::<Vec<_>>();
72 if xlora_classifier.len() > 1 {
73 warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
74 warn!("Selected classifier: `{}`", &xlora_classifier[0]);
75 }
76 let xlora_classifier = xlora_classifier.first();
77
78 let classifier_path =
79 xlora_classifier.map(|xlora_classifier| api_get_file!(api, xlora_classifier, model_id));
80
81 let xlora_configs = &api_dir_list!(api, model_id)
84 .filter(|x| x.contains("xlora_config.json"))
85 .collect::<Vec<_>>();
86 if xlora_configs.len() > 1 {
87 warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
88 }
89
90 let mut xlora_config: Option<XLoraConfig> = None;
91 let mut last_err: Option<serde_json::Error> = None;
92 for (i, config_path) in xlora_configs.iter().enumerate() {
93 if xlora_configs.len() != 1 {
94 warn!("Selecting config: `{}`", config_path);
95 }
96 let config_path = api_get_file!(api, config_path, model_id);
97 let conf = fs::read_to_string(config_path)?;
98 let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
99 match deser {
100 Ok(conf) => {
101 xlora_config = Some(conf);
102 break;
103 }
104 Err(e) => {
105 if i != xlora_configs.len() - 1 {
106 warn!("Config is broken with error `{e}`");
107 }
108 last_err = Some(e);
109 }
110 }
111 }
112 let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
113 if let Some(last_err) = last_err {
114 panic!(
115 "Unable to derserialize any configs. Last error: {}",
116 last_err
117 )
118 } else {
119 None
120 }
121 });
122
123 let adapter_files = api_dir_list!(api, model_id)
125 .filter_map(|name| {
126 if let Some(ref adapters) = xlora_order.as_ref().unwrap().adapters {
127 for adapter_name in adapters {
128 if name.contains(adapter_name) {
129 return Some((name, adapter_name.clone()));
130 }
131 }
132 }
133 None
134 })
135 .collect::<Vec<_>>();
136 if adapter_files.is_empty() && xlora_order.as_ref().unwrap().adapters.is_some() {
137 anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
138 }
139
140 let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
142 for (file, name) in adapter_files {
143 if let Some(paths) = adapters_paths.get_mut(&name) {
144 paths.push(api_get_file!(api, &file, model_id));
145 } else {
146 adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
147 }
148 }
149
150 let mut adapters_configs = Vec::new();
152 let mut adapters_safetensors = Vec::new();
153 if let Some(ref adapters) = xlora_order.as_ref().unwrap().adapters {
154 for (i, name) in adapters.iter().enumerate() {
155 let paths = adapters_paths
156 .get(name)
157 .unwrap_or_else(|| panic!("Adapter {name} not found."));
158 for path in paths {
159 if path.extension().unwrap() == "safetensors" {
160 adapters_safetensors.push((name.clone(), path.to_owned()));
161 } else {
162 let conf = fs::read_to_string(path)?;
163 let lora_config: LoraConfig = serde_json::from_str(&conf)?;
164 adapters_configs.push((((i + 1).to_string(), name.clone()), lora_config));
165 }
166 }
167 }
168 }
169
170 if xlora_order.as_ref().is_some_and(|order| {
172 &order.base_model_id
173 != xlora_config
174 .as_ref()
175 .map(|cfg| &cfg.base_model_id)
176 .unwrap_or(&base_model_id)
177 }) || xlora_config
178 .as_ref()
179 .map(|cfg| &cfg.base_model_id)
180 .unwrap_or(&base_model_id)
181 != &base_model_id
182 {
183 anyhow::bail!(
184 "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
185 xlora_order.as_ref().unwrap().base_model_id,
186 xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
187 base_model_id
188 );
189 }
190
191 let lora_preload_adapter_info = if let Some(xlora_order) = xlora_order {
192 if let Some(preload_adapters) = &xlora_order.preload_adapters {
194 let mut output = HashMap::new();
195 for adapter in preload_adapters {
196 let adapter_files = api_dir_list!(api, &adapter.adapter_model_id)
198 .filter_map(|f| {
199 if f.contains(&adapter.name) {
200 Some((f, adapter.name.clone()))
201 } else {
202 None
203 }
204 })
205 .collect::<Vec<_>>();
206 if adapter_files.is_empty() {
207 anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
208 }
209 let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
211 for (file, name) in adapter_files {
212 if let Some(paths) = adapters_paths.get_mut(&name) {
213 paths.push(api_get_file!(api, &file, model_id));
214 } else {
215 adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
216 }
217 }
218
219 let mut config = None;
220 let mut safetensor = None;
221
222 let paths = adapters_paths
224 .get(&adapter.name)
225 .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
226 for path in paths {
227 if path.extension().unwrap() == "safetensors" {
228 safetensor = Some(path.to_owned());
229 } else {
230 let conf = fs::read_to_string(path)?;
231 let lora_config: LoraConfig = serde_json::from_str(&conf)?;
232 config = Some(lora_config);
233 }
234 }
235
236 let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
237 output.insert(adapter.name.clone(), (safetensor, config));
238 }
239 Some(output)
240 } else {
241 None
242 }
243 } else {
244 None
245 };
246
247 XLoraPaths {
248 adapter_configs: Some(adapters_configs),
249 adapter_safetensors: Some(adapters_safetensors),
250 classifier_path,
251 xlora_order: xlora_order.clone(),
252 xlora_config,
253 lora_preload_adapter_info,
254 }
255 } else {
256 XLoraPaths {
257 adapter_configs: None,
258 adapter_safetensors: None,
259 classifier_path: None,
260 xlora_order: None,
261 xlora_config: None,
262 lora_preload_adapter_info: None,
263 }
264 })
265}
266
267pub fn get_model_paths(
268 revision: String,
269 token_source: &TokenSource,
270 quantized_model_id: &Option<String>,
271 quantized_filename: &Option<Vec<String>>,
272 api: &ApiRepo,
273 model_id: &Path,
274 loading_from_uqff: bool,
275) -> Result<Vec<PathBuf>> {
276 match &quantized_filename {
277 Some(names) => {
278 let id = quantized_model_id.as_ref().unwrap();
279 let mut files = Vec::new();
280
281 for name in names {
282 let qapi = {
283 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
284 let mut api = ApiBuilder::from_cache(cache)
285 .with_progress(true)
286 .with_token(get_token(token_source)?);
287 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
288 api = api.with_cache_dir(x.into());
289 }
290 api.build().map_err(candle_core::Error::msg)?
291 };
292 let qapi = qapi.repo(Repo::with_revision(
293 id.to_string(),
294 RepoType::Model,
295 revision.clone(),
296 ));
297 let model_id = Path::new(&id);
298 files.push(api_get_file!(qapi, name, model_id));
299 }
300 Ok(files)
301 }
302 None => {
303 let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
305 let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
306 let pickle_match = Regex::new(PICKLE_MATCH)?;
307
308 let mut filenames = vec![];
309 let listing = api_dir_list!(api, model_id).filter(|x| {
310 safetensor_match.is_match(x)
311 || pickle_match.is_match(x)
312 || quant_safetensor_match.is_match(x)
313 || x == UQFF_RESIDUAL_SAFETENSORS
314 });
315 let safetensors = listing
316 .clone()
317 .filter(|x| x.ends_with(".safetensors"))
318 .collect::<Vec<_>>();
319 let pickles = listing
320 .clone()
321 .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
322 .collect::<Vec<_>>();
323 let uqff_residual = listing
324 .clone()
325 .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
326 .collect::<Vec<_>>();
327 let files = if !safetensors.is_empty() {
328 safetensors
330 } else if !pickles.is_empty() {
331 pickles
333 } else if !uqff_residual.is_empty() && loading_from_uqff {
334 uqff_residual
335 } else {
336 anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
337 };
338 info!(
339 "Found model weight filenames {:?}",
340 files
341 .iter()
342 .map(|x| x.split('/').next_back().unwrap())
343 .collect::<Vec<_>>()
344 );
345 for rfilename in files {
346 filenames.push(api_get_file!(api, &rfilename, model_id));
347 }
348 Ok(filenames)
349 }
350 }
351}
352
353#[allow(clippy::borrowed_box)]
366pub(crate) fn get_chat_template(
367 paths: &Box<dyn ModelPaths>,
368 jinja_explicit: &Option<String>,
369 chat_template_explicit: &Option<String>,
370 chat_template_fallback: &Option<String>,
371 chat_template_ovrd: Option<String>,
372) -> ChatTemplate {
373 let template_content = if let Some(template_filename) = paths.get_template_filename() {
375 if !["jinja", "json"].contains(
376 &template_filename
377 .extension()
378 .expect("Template filename must be a file")
379 .to_string_lossy()
380 .to_string()
381 .as_str(),
382 ) {
383 panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
384 }
385 Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
386 } else if chat_template_fallback
387 .as_ref()
388 .is_some_and(|f| f.ends_with(".json"))
389 {
390 let template_filename = chat_template_fallback
392 .as_ref()
393 .expect("A tokenizer config or chat template file path must be specified.");
394 Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
395 } else if chat_template_ovrd.is_some() {
396 None
397 } else {
398 panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there. If you are running a GGUF model, it probably does not contain a chat template.");
399 };
400 let mut template: ChatTemplate = match chat_template_ovrd {
401 Some(chat_template) => {
402 info!("Using literal chat template.");
404 let mut template = ChatTemplate::default();
405 template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
406 template
407 }
408 None => serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap(),
409 };
410 if template.chat_template.is_none() {
412 if let Some(chat_template_explicit) = chat_template_explicit {
413 let ct =
414 fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
415
416 let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
417 ct
418 } else {
419 #[derive(Debug, serde::Deserialize)]
420 struct AutomaticTemplate {
421 chat_template: String,
422 }
423 let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
424 deser.chat_template
425 };
426
427 template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
428 }
429 }
430
431 if let Some(jinja_explicit) = jinja_explicit {
433 if !jinja_explicit.ends_with(".jinja") {
434 panic!("jinja_explicit must end with .jinja!");
435 }
436
437 let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
438
439 template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
440 }
441
442 let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
443 .get_processor_config()
444 .as_ref()
445 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
446 if let Some(processor_conf) = processor_conf {
447 if processor_conf.chat_template.is_some() {
448 template.chat_template = processor_conf
449 .chat_template
450 .map(|x| ChatTemplateValue(Either::Left(x)));
451 }
452 }
453
454 #[derive(Debug, serde::Deserialize)]
455 struct SpecifiedTemplate {
456 chat_template: String,
457 bos_token: Option<String>,
458 eos_token: Option<String>,
459 unk_token: Option<String>,
460 }
461
462 if template.chat_template.is_some() {
463 return template;
464 };
465
466 match &template.chat_template {
467 Some(_) => template,
468 None => {
469 info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
470 let mut deser: HashMap<String, Value> =
471 serde_json::from_str(&template_content.unwrap()).unwrap();
472
473 match chat_template_fallback.clone() {
474 Some(t) => {
475 info!("Loading specified loading chat template file at `{t}`.");
476 let templ: SpecifiedTemplate =
477 serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
478 deser.insert(
479 "chat_template".to_string(),
480 Value::String(templ.chat_template),
481 );
482 if templ.bos_token.is_some() {
483 deser.insert(
484 "bos_token".to_string(),
485 Value::String(templ.bos_token.unwrap()),
486 );
487 }
488 if templ.eos_token.is_some() {
489 deser.insert(
490 "eos_token".to_string(),
491 Value::String(templ.eos_token.unwrap()),
492 );
493 }
494 if templ.unk_token.is_some() {
495 deser.insert(
496 "unk_token".to_string(),
497 Value::String(templ.unk_token.unwrap()),
498 );
499 }
500 }
501 None => {
502 info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
503 deser.insert("chat_template".to_string(), Value::Null);
504 }
505 }
506
507 let ser = serde_json::to_string_pretty(&deser)
508 .expect("Serialization of modified chat template failed.");
509 serde_json::from_str(&ser).unwrap()
510 }
511 }
512}
513
/// Unit tests for the weight-filename regexes.
#[cfg(test)]
mod tests {
    use regex_automata::meta::Regex;

    use super::{PICKLE_MATCH, SAFETENSOR_MATCH};

    /// Sharded safetensors names match; malformed or unsharded names do not.
    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;

        let positive_ids = [
            "model-00001-of-00001.safetensors",
            "model-00002-of-00002.safetensors",
            "model-00003-of-00003.safetensors",
            "model-00004-of-00004.safetensors",
            "model-00005-of-00005.safetensors",
            "model-00006-of-00006.safetensors",
        ];
        let negative_ids = [
            "model-0000a-of-00002.safetensors",
            "consolidated.safetensors",
        ];
        for id in positive_ids {
            assert!(safetensor_match.is_match(id));
        }
        for id in negative_ids {
            assert!(!safetensor_match.is_match(id));
        }
        Ok(())
    }

    /// Sharded pickle names with five-digit indices match; other names do not.
    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        let pickle_match = Regex::new(PICKLE_MATCH)?;

        let positive_ids = [
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ];
        let negative_ids = [
            "pytorch_model-000001-of-00001.bin",
            "pytorch_model-0000a-of-00002.bin",
            "pytorch_model-000-of-00003.bin",
            "pytorch_consolidated.bin",
        ];
        for id in positive_ids {
            assert!(pickle_match.is_match(id));
        }
        for id in negative_ids {
            assert!(!pickle_match.is_match(id));
        }
        Ok(())
    }
}
568}