mistralrs_core/pipeline/
paths.rs

1use std::{
2    collections::HashMap,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10    api::sync::{ApiBuilder, ApiRepo},
11    Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{info, warn};
16
17use crate::{
18    api_dir_list, api_get_file,
19    lora::LoraConfig,
20    pipeline::{
21        chat_template::{ChatTemplate, ChatTemplateValue},
22        isq::UQFF_RESIDUAL_SAFETENSORS,
23    },
24    utils::tokens::get_token,
25    xlora_models::XLoraConfig,
26    ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
// Weight files are matched against these patterns rather than by extension alone, which avoids
// picking up files such as `consolidated.safetensors`.
/// Sharded safetensors weights, e.g. `model-00001-of-00002.safetensors`.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
/// Single-file safetensors weights: `model.safetensors`.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
/// Sharded pickle weights, e.g. `pytorch_model-00001-of-00002.bin` (also `.pth` / `.pt`).
/// The dot before the extension is escaped so that it only matches a literal `.`.
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";
33
34#[derive(Clone, Debug)]
35pub struct LoraAdapterPaths {
36    pub lora_config: mistralrs_quant::LoraConfig,
37    pub adapter_path: PathBuf,
38}
39
40#[allow(clippy::large_enum_variant)]
41#[derive(Clone, Debug)]
42pub enum AdapterPaths {
43    XLora {
44        adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
45        adapter_safetensors: Option<Vec<(String, PathBuf)>>,
46        classifier_path: Option<PathBuf>,
47        xlora_order: Option<Ordering>,
48        xlora_config: Option<XLoraConfig>,
49        lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
50    },
51    Lora(Vec<LoraAdapterPaths>),
52    None,
53}
54
55pub fn get_xlora_paths(
56    base_model_id: String,
57    xlora_model_id: Option<&String>,
58    lora_adapter_ids: Option<&Vec<String>>,
59    token_source: &TokenSource,
60    revision: String,
61    xlora_order: Option<&Ordering>,
62) -> Result<AdapterPaths> {
63    match (lora_adapter_ids, xlora_model_id, xlora_order) {
64        (None, Some(xlora_id), Some(xlora_order)) => {
65            let api = {
66                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
67                let mut api = ApiBuilder::from_cache(cache)
68                    .with_progress(true)
69                    .with_token(get_token(token_source)?);
70                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
71                    api = api.with_cache_dir(x.into());
72                }
73                api.build().map_err(candle_core::Error::msg)?
74            };
75            let api = api.repo(Repo::with_revision(
76                xlora_id.clone(),
77                RepoType::Model,
78                revision,
79            ));
80            let model_id = Path::new(&xlora_id);
81
82            // Get the path for the xlora classifier
83            let xlora_classifier = &api_dir_list!(api, model_id)
84                .filter(|x| x.contains("xlora_classifier.safetensors"))
85                .collect::<Vec<_>>();
86            if xlora_classifier.len() > 1 {
87                warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
88                warn!("Selected classifier: `{}`", &xlora_classifier[0]);
89            }
90            let xlora_classifier = xlora_classifier.first();
91
92            let classifier_path = xlora_classifier
93                .map(|xlora_classifier| api_get_file!(api, xlora_classifier, model_id));
94
95            // Get the path for the xlora config by checking all for valid versions.
96            // NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable
97            let xlora_configs = &api_dir_list!(api, model_id)
98                .filter(|x| x.contains("xlora_config.json"))
99                .collect::<Vec<_>>();
100            if xlora_configs.len() > 1 {
101                warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
102            }
103
104            let mut xlora_config: Option<XLoraConfig> = None;
105            let mut last_err: Option<serde_json::Error> = None;
106            for (i, config_path) in xlora_configs.iter().enumerate() {
107                if xlora_configs.len() != 1 {
108                    warn!("Selecting config: `{}`", config_path);
109                }
110                let config_path = api_get_file!(api, config_path, model_id);
111                let conf = fs::read_to_string(config_path)?;
112                let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
113                match deser {
114                    Ok(conf) => {
115                        xlora_config = Some(conf);
116                        break;
117                    }
118                    Err(e) => {
119                        if i != xlora_configs.len() - 1 {
120                            warn!("Config is broken with error `{e}`");
121                        }
122                        last_err = Some(e);
123                    }
124                }
125            }
126            let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
127                if let Some(last_err) = last_err {
128                    panic!(
129                        "Unable to derserialize any configs. Last error: {}",
130                        last_err
131                    )
132                } else {
133                    None
134                }
135            });
136
137            // If there are adapters in the ordering file, get their names and remote paths
138            let adapter_files = api_dir_list!(api, model_id)
139                .filter_map(|name| {
140                    if let Some(ref adapters) = xlora_order.adapters {
141                        for adapter_name in adapters {
142                            if name.contains(adapter_name) {
143                                return Some((name, adapter_name.clone()));
144                            }
145                        }
146                    }
147                    None
148                })
149                .collect::<Vec<_>>();
150            if adapter_files.is_empty() && xlora_order.adapters.is_some() {
151                anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
152            }
153
154            // Get the local paths for each adapter
155            let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
156            for (file, name) in adapter_files {
157                if let Some(paths) = adapters_paths.get_mut(&name) {
158                    paths.push(api_get_file!(api, &file, model_id));
159                } else {
160                    adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
161                }
162            }
163
164            // Sort local paths for the adapter configs and safetensors files
165            let mut adapters_configs = Vec::new();
166            let mut adapters_safetensors = Vec::new();
167            if let Some(ref adapters) = xlora_order.adapters {
168                for (i, name) in adapters.iter().enumerate() {
169                    let paths = adapters_paths
170                        .get(name)
171                        .unwrap_or_else(|| panic!("Adapter {name} not found."));
172                    for path in paths {
173                        if path.extension().unwrap() == "safetensors" {
174                            adapters_safetensors.push((name.clone(), path.to_owned()));
175                        } else {
176                            let conf = fs::read_to_string(path)?;
177                            let lora_config: LoraConfig = serde_json::from_str(&conf)?;
178                            adapters_configs
179                                .push((((i + 1).to_string(), name.clone()), lora_config));
180                        }
181                    }
182                }
183            }
184
185            // Make sure they all match
186            if xlora_order.base_model_id
187                != *xlora_config
188                    .as_ref()
189                    .map(|cfg| &cfg.base_model_id)
190                    .unwrap_or(&base_model_id)
191                || xlora_config
192                    .as_ref()
193                    .map(|cfg| &cfg.base_model_id)
194                    .unwrap_or(&base_model_id)
195                    != &base_model_id
196            {
197                anyhow::bail!(
198                    "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
199                    xlora_order.base_model_id,
200                    xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
201                    base_model_id
202                );
203            }
204
205            let lora_preload_adapter_info =
206                // If preload adapters are specified, get their metadata like above
207                if let Some(preload_adapters) = &xlora_order.preload_adapters {
208                    let mut output = HashMap::new();
209                    for adapter in preload_adapters {
210                        // Get the names and remote paths of the files associated with this adapter
211                        let adapter_files = api_dir_list!(api, &adapter.adapter_model_id)
212                            .filter_map(|f| {
213                                if f.contains(&adapter.name) {
214                                    Some((f, adapter.name.clone()))
215                                } else {
216                                    None
217                                }
218                            })
219                            .collect::<Vec<_>>();
220                        if adapter_files.is_empty() {
221                            anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
222                        }
223                        // Get local paths for this adapter
224                        let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
225                        for (file, name) in adapter_files {
226                            if let Some(paths) = adapters_paths.get_mut(&name) {
227                                paths.push(api_get_file!(api, &file, model_id));
228                            } else {
229                                adapters_paths
230                                    .insert(name, vec![api_get_file!(api, &file, model_id)]);
231                            }
232                        }
233
234                        let mut config = None;
235                        let mut safetensor = None;
236
237                        // Sort local paths for the adapter configs and safetensors files
238                        let paths = adapters_paths
239                            .get(&adapter.name)
240                            .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
241                        for path in paths {
242                            if path.extension().unwrap() == "safetensors" {
243                                safetensor = Some(path.to_owned());
244                            } else {
245                                let conf = fs::read_to_string(path)?;
246                                let lora_config: LoraConfig = serde_json::from_str(&conf)?;
247                                config = Some(lora_config);
248                            }
249                        }
250
251                        let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
252                        output.insert(adapter.name.clone(), (safetensor, config));
253                    }
254                    Some(output)
255                } else {
256                    None
257                };
258
259            Ok(AdapterPaths::XLora {
260                adapter_configs: Some(adapters_configs),
261                adapter_safetensors: Some(adapters_safetensors),
262                classifier_path,
263                xlora_order: Some(xlora_order.clone()),
264                xlora_config,
265                lora_preload_adapter_info,
266            })
267        }
268        (Some(adapter_ids), None, None) => {
269            let mut lora_adapter_paths = Vec::new();
270            for adapter_id in adapter_ids {
271                info!("Loading adapter at `{adapter_id}`");
272
273                let api = {
274                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
275                    let mut api = ApiBuilder::from_cache(cache)
276                        .with_progress(true)
277                        .with_token(get_token(token_source)?);
278                    if let Ok(x) = std::env::var("HF_HUB_CACHE") {
279                        api = api.with_cache_dir(x.into());
280                    }
281                    api.build().map_err(candle_core::Error::msg)?
282                };
283                let api = api.repo(Repo::with_revision(
284                    adapter_id.clone(),
285                    RepoType::Model,
286                    revision.clone(),
287                ));
288
289                let config_path = api.get("adapter_config.json")?;
290                let adapter_path = api.get("adapter_model.safetensors")?;
291                let lora_config: mistralrs_quant::LoraConfig =
292                    serde_json::from_str(&fs::read_to_string(config_path)?)?;
293
294                lora_adapter_paths.push(LoraAdapterPaths {
295                    lora_config,
296                    adapter_path,
297                });
298            }
299
300            Ok(AdapterPaths::Lora(lora_adapter_paths))
301        }
302        (None, None, None) => Ok(AdapterPaths::None),
303        _ => anyhow::bail!(
304            "Incorrect configuration for an adapter model. Lora and XLora are mutually exclusive."
305        ),
306    }
307}
308
309pub fn get_model_paths(
310    revision: String,
311    token_source: &TokenSource,
312    quantized_model_id: Option<&String>,
313    quantized_filename: Option<&Vec<String>>,
314    api: &ApiRepo,
315    model_id: &Path,
316    loading_from_uqff: bool,
317) -> Result<Vec<PathBuf>> {
318    match quantized_filename {
319        Some(names) => {
320            let id = quantized_model_id.unwrap();
321            let mut files = Vec::new();
322
323            for name in names {
324                let qapi = {
325                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
326                    let mut api = ApiBuilder::from_cache(cache)
327                        .with_progress(true)
328                        .with_token(get_token(token_source)?);
329                    if let Ok(x) = std::env::var("HF_HUB_CACHE") {
330                        api = api.with_cache_dir(x.into());
331                    }
332                    api.build().map_err(candle_core::Error::msg)?
333                };
334                let qapi = qapi.repo(Repo::with_revision(
335                    id.to_string(),
336                    RepoType::Model,
337                    revision.clone(),
338                ));
339                let model_id = Path::new(&id);
340                files.push(api_get_file!(qapi, name, model_id));
341            }
342            Ok(files)
343        }
344        None => {
345            // We only match these patterns for model names
346            let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
347            let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
348            let pickle_match = Regex::new(PICKLE_MATCH)?;
349
350            let mut filenames = vec![];
351            let listing = api_dir_list!(api, model_id).filter(|x| {
352                safetensor_match.is_match(x)
353                    || pickle_match.is_match(x)
354                    || quant_safetensor_match.is_match(x)
355                    || x == UQFF_RESIDUAL_SAFETENSORS
356            });
357            let safetensors = listing
358                .clone()
359                .filter(|x| x.ends_with(".safetensors"))
360                .collect::<Vec<_>>();
361            let pickles = listing
362                .clone()
363                .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
364                .collect::<Vec<_>>();
365            let uqff_residual = listing
366                .clone()
367                .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
368                .collect::<Vec<_>>();
369            let files = if !safetensors.is_empty() {
370                // Always prefer safetensors
371                safetensors
372            } else if !pickles.is_empty() {
373                // Fall back to pickle
374                pickles
375            } else if !uqff_residual.is_empty() && loading_from_uqff {
376                uqff_residual
377            } else {
378                anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
379            };
380            info!(
381                "Found model weight filenames {:?}",
382                files
383                    .iter()
384                    .map(|x| x.split('/').next_back().unwrap())
385                    .collect::<Vec<_>>()
386            );
387            for rfilename in files {
388                filenames.push(api_get_file!(api, &rfilename, model_id));
389            }
390            Ok(filenames)
391        }
392    }
393}
394
395/// Find and parse the appropriate [`ChatTemplate`], and ensure is has a valid [`ChatTemplate.chat_template`].
396/// If the provided `tokenizer_config.json` from [`ModelPaths.get_template_filename`] does not
397/// have a `chat_template`, use the provided one.
398///
399/// - Uses `chat_template_fallback` if `paths` does not contain a chat template file. This may be a literal or .json file.
400/// - `chat_template_ovrd` (GGUF chat template content) causes the usage of that string chat template initially.
401///   Falls back to `chat_template_file` if it is invalid. *The user must add the bos/unk/eos tokens manually if this
402///   is used.*
403///
404/// THE FOLLOWING IS IGNORED:
405/// After this, if the `chat_template_explicit` filename is specified (a json with one field: "chat_template" OR a jinja file),
406///  the chat template is overwritten with this chat template.
407#[allow(clippy::borrowed_box)]
408pub(crate) fn get_chat_template(
409    paths: &Box<dyn ModelPaths>,
410    jinja_explicit: Option<&String>,
411    chat_template_explicit: Option<&String>,
412    chat_template_fallback: Option<&String>,
413    chat_template_ovrd: Option<String>,
414) -> ChatTemplate {
415    // Get template content, this may be overridden.
416    let template_content = if let Some(template_filename) = paths.get_template_filename() {
417        if !["jinja", "json"].contains(
418            &template_filename
419                .extension()
420                .expect("Template filename must be a file")
421                .to_string_lossy()
422                .to_string()
423                .as_str(),
424        ) {
425            panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
426        }
427        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
428    } else if chat_template_fallback.is_some_and(|f| f.ends_with(".json")) {
429        // User specified a file
430        let template_filename = chat_template_fallback
431            .expect("A tokenizer config or chat template file path must be specified.");
432        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
433    } else if chat_template_ovrd.is_some() {
434        None
435    } else {
436        panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there. If you are running a GGUF model, it probably does not contain a chat template.");
437    };
438    let mut template: ChatTemplate = match chat_template_ovrd {
439        Some(chat_template) => {
440            // In this case the override chat template is being used. The user must add the bos/eos/unk toks themselves.
441            info!("Using literal chat template.");
442            let mut template = ChatTemplate::default();
443            template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
444            template
445        }
446        None => serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap(),
447    };
448    // Overwrite to use any present `chat_template.json`, only if there is not one present already.
449    if template.chat_template.is_none() {
450        if let Some(chat_template_explicit) = chat_template_explicit {
451            let ct =
452                fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
453
454            let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
455                ct
456            } else {
457                #[derive(Debug, serde::Deserialize)]
458                struct AutomaticTemplate {
459                    chat_template: String,
460                }
461                let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
462                deser.chat_template
463            };
464
465            template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
466        }
467    }
468
469    // JINJA explicit
470    if let Some(jinja_explicit) = jinja_explicit {
471        if !jinja_explicit.ends_with(".jinja") {
472            panic!("jinja_explicit must end with .jinja!");
473        }
474
475        let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
476
477        template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
478    }
479
480    let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
481        .get_processor_config()
482        .as_ref()
483        .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
484    if let Some(processor_conf) = processor_conf {
485        if processor_conf.chat_template.is_some() {
486            template.chat_template = processor_conf
487                .chat_template
488                .map(|x| ChatTemplateValue(Either::Left(x)));
489        }
490    }
491
492    #[derive(Debug, serde::Deserialize)]
493    struct SpecifiedTemplate {
494        chat_template: String,
495        bos_token: Option<String>,
496        eos_token: Option<String>,
497        unk_token: Option<String>,
498    }
499
500    if template.chat_template.is_some() {
501        return template;
502    };
503
504    match &template.chat_template {
505        Some(_) => template,
506        None => {
507            info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
508            let mut deser: HashMap<String, Value> =
509                serde_json::from_str(&template_content.unwrap()).unwrap();
510
511            match chat_template_fallback.cloned() {
512                Some(t) => {
513                    info!("Loading specified loading chat template file at `{t}`.");
514                    let templ: SpecifiedTemplate =
515                        serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
516                    deser.insert(
517                        "chat_template".to_string(),
518                        Value::String(templ.chat_template),
519                    );
520                    if templ.bos_token.is_some() {
521                        deser.insert(
522                            "bos_token".to_string(),
523                            Value::String(templ.bos_token.unwrap()),
524                        );
525                    }
526                    if templ.eos_token.is_some() {
527                        deser.insert(
528                            "eos_token".to_string(),
529                            Value::String(templ.eos_token.unwrap()),
530                        );
531                    }
532                    if templ.unk_token.is_some() {
533                        deser.insert(
534                            "unk_token".to_string(),
535                            Value::String(templ.unk_token.unwrap()),
536                        );
537                    }
538                }
539                None => {
540                    info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
541                    deser.insert("chat_template".to_string(), Value::Null);
542                }
543            }
544
545            let ser = serde_json::to_string_pretty(&deser)
546                .expect("Serialization of modified chat template failed.");
547            serde_json::from_str(&ser).unwrap()
548        }
549    }
550}
551
/// Checks of the weight-file name patterns against representative file names.
#[cfg(test)]
mod tests {
    use regex_automata::meta::Regex;

    use super::{PICKLE_MATCH, QUANT_SAFETENSOR_MATCH, SAFETENSOR_MATCH};

    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;

        let positive_ids = [
            "model-00001-of-00001.safetensors",
            "model-00002-of-00002.safetensors",
            "model-00003-of-00003.safetensors",
            "model-00004-of-00004.safetensors",
            "model-00005-of-00005.safetensors",
            "model-00006-of-00006.safetensors",
        ];
        let negative_ids = [
            "model-0000a-of-00002.safetensors",
            "consolidated.safetensors",
        ];
        for id in positive_ids {
            assert!(safetensor_match.is_match(id), "`{id}` should match");
        }
        for id in negative_ids {
            assert!(!safetensor_match.is_match(id), "`{id}` should not match");
        }
        Ok(())
    }

    #[test]
    fn match_quant_safetensors() -> anyhow::Result<()> {
        let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;

        let positive_ids = ["model.safetensors"];
        let negative_ids = [
            "model-00001-of-00002.safetensors",
            "consolidated.safetensors",
        ];
        for id in positive_ids {
            assert!(quant_safetensor_match.is_match(id), "`{id}` should match");
        }
        for id in negative_ids {
            assert!(
                !quant_safetensor_match.is_match(id),
                "`{id}` should not match"
            );
        }
        Ok(())
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        let pickle_match = Regex::new(PICKLE_MATCH)?;

        let positive_ids = [
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ];
        let negative_ids = [
            "pytorch_model-000001-of-00001.bin",
            "pytorch_model-0000a-of-00002.bin",
            "pytorch_model-000-of-00003.bin",
            "pytorch_consolidated.bin",
        ];
        for id in positive_ids {
            assert!(pickle_match.is_match(id), "`{id}` should match");
        }
        for id in negative_ids {
            assert!(!pickle_match.is_match(id), "`{id}` should not match");
        }
        Ok(())
    }
}
606}