mistralrs_core/pipeline/
paths.rs

1use std::{
2    collections::HashMap,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10    api::sync::{ApiBuilder, ApiRepo},
11    Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{info, warn};
16
17use crate::{
18    api_dir_list, api_get_file,
19    lora::LoraConfig,
20    pipeline::{
21        chat_template::{ChatTemplate, ChatTemplateValue},
22        isq::UQFF_RESIDUAL_SAFETENSORS,
23    },
24    utils::tokens::get_token,
25    xlora_models::XLoraConfig,
26    ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
// Weight files are selected by matching their listed names against these patterns, which avoids
// picking up files like `consolidated.safetensors`. The patterns carry no `^`/`$` anchors, so a
// match may occur anywhere inside the listed name.

/// Sharded safetensors weights, e.g. `model-00001-of-00002.safetensors`.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
/// Single-file safetensors weights, e.g. `model.safetensors`.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
/// Sharded pickle weights, e.g. `pytorch_model-00001-of-00002.bin` (also `.pth` / `.pt`).
/// The dot before the extension is escaped so that it only matches a literal `.`.
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.(?:pth|pt|bin)\b";
33
34pub(crate) struct XLoraPaths {
35    pub adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
36    pub adapter_safetensors: Option<Vec<(String, PathBuf)>>,
37    pub classifier_path: Option<PathBuf>,
38    pub xlora_order: Option<Ordering>,
39    pub xlora_config: Option<XLoraConfig>,
40    pub lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
41}
42
43pub fn get_xlora_paths(
44    base_model_id: String,
45    xlora_model_id: &Option<String>,
46    token_source: &TokenSource,
47    revision: String,
48    xlora_order: &Option<Ordering>,
49) -> Result<XLoraPaths> {
50    Ok(if let Some(ref xlora_id) = xlora_model_id {
51        let api = {
52            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
53            let mut api = ApiBuilder::from_cache(cache)
54                .with_progress(true)
55                .with_token(get_token(token_source)?);
56            if let Ok(x) = std::env::var("HF_HUB_CACHE") {
57                api = api.with_cache_dir(x.into());
58            }
59            api.build().map_err(candle_core::Error::msg)?
60        };
61        let api = api.repo(Repo::with_revision(
62            xlora_id.clone(),
63            RepoType::Model,
64            revision,
65        ));
66        let model_id = Path::new(&xlora_id);
67
68        // Get the path for the xlora classifier
69        let xlora_classifier = &api_dir_list!(api, model_id)
70            .filter(|x| x.contains("xlora_classifier.safetensors"))
71            .collect::<Vec<_>>();
72        if xlora_classifier.len() > 1 {
73            warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
74            warn!("Selected classifier: `{}`", &xlora_classifier[0]);
75        }
76        let xlora_classifier = xlora_classifier.first();
77
78        let classifier_path =
79            xlora_classifier.map(|xlora_classifier| api_get_file!(api, xlora_classifier, model_id));
80
81        // Get the path for the xlora config by checking all for valid versions.
82        // NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable
83        let xlora_configs = &api_dir_list!(api, model_id)
84            .filter(|x| x.contains("xlora_config.json"))
85            .collect::<Vec<_>>();
86        if xlora_configs.len() > 1 {
87            warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
88        }
89
90        let mut xlora_config: Option<XLoraConfig> = None;
91        let mut last_err: Option<serde_json::Error> = None;
92        for (i, config_path) in xlora_configs.iter().enumerate() {
93            if xlora_configs.len() != 1 {
94                warn!("Selecting config: `{}`", config_path);
95            }
96            let config_path = api_get_file!(api, config_path, model_id);
97            let conf = fs::read_to_string(config_path)?;
98            let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
99            match deser {
100                Ok(conf) => {
101                    xlora_config = Some(conf);
102                    break;
103                }
104                Err(e) => {
105                    if i != xlora_configs.len() - 1 {
106                        warn!("Config is broken with error `{e}`");
107                    }
108                    last_err = Some(e);
109                }
110            }
111        }
112        let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
113            if let Some(last_err) = last_err {
114                panic!(
115                    "Unable to derserialize any configs. Last error: {}",
116                    last_err
117                )
118            } else {
119                None
120            }
121        });
122
123        // If there are adapters in the ordering file, get their names and remote paths
124        let adapter_files = api_dir_list!(api, model_id)
125            .filter_map(|name| {
126                if let Some(ref adapters) = xlora_order.as_ref().unwrap().adapters {
127                    for adapter_name in adapters {
128                        if name.contains(adapter_name) {
129                            return Some((name, adapter_name.clone()));
130                        }
131                    }
132                }
133                None
134            })
135            .collect::<Vec<_>>();
136        if adapter_files.is_empty() && xlora_order.as_ref().unwrap().adapters.is_some() {
137            anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
138        }
139
140        // Get the local paths for each adapter
141        let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
142        for (file, name) in adapter_files {
143            if let Some(paths) = adapters_paths.get_mut(&name) {
144                paths.push(api_get_file!(api, &file, model_id));
145            } else {
146                adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
147            }
148        }
149
150        // Sort local paths for the adapter configs and safetensors files
151        let mut adapters_configs = Vec::new();
152        let mut adapters_safetensors = Vec::new();
153        if let Some(ref adapters) = xlora_order.as_ref().unwrap().adapters {
154            for (i, name) in adapters.iter().enumerate() {
155                let paths = adapters_paths
156                    .get(name)
157                    .unwrap_or_else(|| panic!("Adapter {name} not found."));
158                for path in paths {
159                    if path.extension().unwrap() == "safetensors" {
160                        adapters_safetensors.push((name.clone(), path.to_owned()));
161                    } else {
162                        let conf = fs::read_to_string(path)?;
163                        let lora_config: LoraConfig = serde_json::from_str(&conf)?;
164                        adapters_configs.push((((i + 1).to_string(), name.clone()), lora_config));
165                    }
166                }
167            }
168        }
169
170        // Make sure they all match
171        if xlora_order.as_ref().is_some_and(|order| {
172            &order.base_model_id
173                != xlora_config
174                    .as_ref()
175                    .map(|cfg| &cfg.base_model_id)
176                    .unwrap_or(&base_model_id)
177        }) || xlora_config
178            .as_ref()
179            .map(|cfg| &cfg.base_model_id)
180            .unwrap_or(&base_model_id)
181            != &base_model_id
182        {
183            anyhow::bail!(
184                "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
185                xlora_order.as_ref().unwrap().base_model_id,
186                xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
187                base_model_id
188            );
189        }
190
191        let lora_preload_adapter_info = if let Some(xlora_order) = xlora_order {
192            // If preload adapters are specified, get their metadata like above
193            if let Some(preload_adapters) = &xlora_order.preload_adapters {
194                let mut output = HashMap::new();
195                for adapter in preload_adapters {
196                    // Get the names and remote paths of the files associated with this adapter
197                    let adapter_files = api_dir_list!(api, &adapter.adapter_model_id)
198                        .filter_map(|f| {
199                            if f.contains(&adapter.name) {
200                                Some((f, adapter.name.clone()))
201                            } else {
202                                None
203                            }
204                        })
205                        .collect::<Vec<_>>();
206                    if adapter_files.is_empty() {
207                        anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
208                    }
209                    // Get local paths for this adapter
210                    let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
211                    for (file, name) in adapter_files {
212                        if let Some(paths) = adapters_paths.get_mut(&name) {
213                            paths.push(api_get_file!(api, &file, model_id));
214                        } else {
215                            adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
216                        }
217                    }
218
219                    let mut config = None;
220                    let mut safetensor = None;
221
222                    // Sort local paths for the adapter configs and safetensors files
223                    let paths = adapters_paths
224                        .get(&adapter.name)
225                        .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
226                    for path in paths {
227                        if path.extension().unwrap() == "safetensors" {
228                            safetensor = Some(path.to_owned());
229                        } else {
230                            let conf = fs::read_to_string(path)?;
231                            let lora_config: LoraConfig = serde_json::from_str(&conf)?;
232                            config = Some(lora_config);
233                        }
234                    }
235
236                    let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
237                    output.insert(adapter.name.clone(), (safetensor, config));
238                }
239                Some(output)
240            } else {
241                None
242            }
243        } else {
244            None
245        };
246
247        XLoraPaths {
248            adapter_configs: Some(adapters_configs),
249            adapter_safetensors: Some(adapters_safetensors),
250            classifier_path,
251            xlora_order: xlora_order.clone(),
252            xlora_config,
253            lora_preload_adapter_info,
254        }
255    } else {
256        XLoraPaths {
257            adapter_configs: None,
258            adapter_safetensors: None,
259            classifier_path: None,
260            xlora_order: None,
261            xlora_config: None,
262            lora_preload_adapter_info: None,
263        }
264    })
265}
266
267pub fn get_model_paths(
268    revision: String,
269    token_source: &TokenSource,
270    quantized_model_id: &Option<String>,
271    quantized_filename: &Option<Vec<String>>,
272    api: &ApiRepo,
273    model_id: &Path,
274    loading_from_uqff: bool,
275) -> Result<Vec<PathBuf>> {
276    match &quantized_filename {
277        Some(names) => {
278            let id = quantized_model_id.as_ref().unwrap();
279            let mut files = Vec::new();
280
281            for name in names {
282                let qapi = {
283                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
284                    let mut api = ApiBuilder::from_cache(cache)
285                        .with_progress(true)
286                        .with_token(get_token(token_source)?);
287                    if let Ok(x) = std::env::var("HF_HUB_CACHE") {
288                        api = api.with_cache_dir(x.into());
289                    }
290                    api.build().map_err(candle_core::Error::msg)?
291                };
292                let qapi = qapi.repo(Repo::with_revision(
293                    id.to_string(),
294                    RepoType::Model,
295                    revision.clone(),
296                ));
297                let model_id = Path::new(&id);
298                files.push(api_get_file!(qapi, name, model_id));
299            }
300            Ok(files)
301        }
302        None => {
303            // We only match these patterns for model names
304            let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
305            let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
306            let pickle_match = Regex::new(PICKLE_MATCH)?;
307
308            let mut filenames = vec![];
309            let listing = api_dir_list!(api, model_id).filter(|x| {
310                safetensor_match.is_match(x)
311                    || pickle_match.is_match(x)
312                    || quant_safetensor_match.is_match(x)
313                    || x == UQFF_RESIDUAL_SAFETENSORS
314            });
315            let safetensors = listing
316                .clone()
317                .filter(|x| x.ends_with(".safetensors"))
318                .collect::<Vec<_>>();
319            let pickles = listing
320                .clone()
321                .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
322                .collect::<Vec<_>>();
323            let uqff_residual = listing
324                .clone()
325                .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
326                .collect::<Vec<_>>();
327            let files = if !safetensors.is_empty() {
328                // Always prefer safetensors
329                safetensors
330            } else if !pickles.is_empty() {
331                // Fall back to pickle
332                pickles
333            } else if !uqff_residual.is_empty() && loading_from_uqff {
334                uqff_residual
335            } else {
336                anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
337            };
338            info!(
339                "Found model weight filenames {:?}",
340                files
341                    .iter()
342                    .map(|x| x.split('/').next_back().unwrap())
343                    .collect::<Vec<_>>()
344            );
345            for rfilename in files {
346                filenames.push(api_get_file!(api, &rfilename, model_id));
347            }
348            Ok(filenames)
349        }
350    }
351}
352
353/// Find and parse the appropriate [`ChatTemplate`], and ensure is has a valid [`ChatTemplate.chat_template`].
354/// If the provided `tokenizer_config.json` from [`ModelPaths.get_template_filename`] does not
355/// have a `chat_template`, use the provided one.
356///
357/// - Uses `chat_template_fallback` if `paths` does not contain a chat template file. This may be a literal or .json file.
358/// - `chat_template_ovrd` (GGUF chat template content) causes the usage of that string chat template initially.
359///   Falls back to `chat_template_file` if it is invalid. *The user must add the bos/unk/eos tokens manually if this
360///   is used.*
361///
362/// THE FOLLOWING IS IGNORED:
363/// After this, if the `chat_template_explicit` filename is specified (a json with one field: "chat_template" OR a jinja file),
364///  the chat template is overwritten with this chat template.
365#[allow(clippy::borrowed_box)]
366pub(crate) fn get_chat_template(
367    paths: &Box<dyn ModelPaths>,
368    jinja_explicit: &Option<String>,
369    chat_template_explicit: &Option<String>,
370    chat_template_fallback: &Option<String>,
371    chat_template_ovrd: Option<String>,
372) -> ChatTemplate {
373    // Get template content, this may be overridden.
374    let template_content = if let Some(template_filename) = paths.get_template_filename() {
375        if !["jinja", "json"].contains(
376            &template_filename
377                .extension()
378                .expect("Template filename must be a file")
379                .to_string_lossy()
380                .to_string()
381                .as_str(),
382        ) {
383            panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
384        }
385        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
386    } else if chat_template_fallback
387        .as_ref()
388        .is_some_and(|f| f.ends_with(".json"))
389    {
390        // User specified a file
391        let template_filename = chat_template_fallback
392            .as_ref()
393            .expect("A tokenizer config or chat template file path must be specified.");
394        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
395    } else if chat_template_ovrd.is_some() {
396        None
397    } else {
398        panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there. If you are running a GGUF model, it probably does not contain a chat template.");
399    };
400    let mut template: ChatTemplate = match chat_template_ovrd {
401        Some(chat_template) => {
402            // In this case the override chat template is being used. The user must add the bos/eos/unk toks themselves.
403            info!("Using literal chat template.");
404            let mut template = ChatTemplate::default();
405            template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
406            template
407        }
408        None => serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap(),
409    };
410    // Overwrite to use any present `chat_template.json`, only if there is not one present already.
411    if template.chat_template.is_none() {
412        if let Some(chat_template_explicit) = chat_template_explicit {
413            let ct =
414                fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
415
416            let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
417                ct
418            } else {
419                #[derive(Debug, serde::Deserialize)]
420                struct AutomaticTemplate {
421                    chat_template: String,
422                }
423                let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
424                deser.chat_template
425            };
426
427            template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
428        }
429    }
430
431    // JINJA explicit
432    if let Some(jinja_explicit) = jinja_explicit {
433        if !jinja_explicit.ends_with(".jinja") {
434            panic!("jinja_explicit must end with .jinja!");
435        }
436
437        let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
438
439        template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
440    }
441
442    let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
443        .get_processor_config()
444        .as_ref()
445        .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
446    if let Some(processor_conf) = processor_conf {
447        if processor_conf.chat_template.is_some() {
448            template.chat_template = processor_conf
449                .chat_template
450                .map(|x| ChatTemplateValue(Either::Left(x)));
451        }
452    }
453
454    #[derive(Debug, serde::Deserialize)]
455    struct SpecifiedTemplate {
456        chat_template: String,
457        bos_token: Option<String>,
458        eos_token: Option<String>,
459        unk_token: Option<String>,
460    }
461
462    if template.chat_template.is_some() {
463        return template;
464    };
465
466    match &template.chat_template {
467        Some(_) => template,
468        None => {
469            info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
470            let mut deser: HashMap<String, Value> =
471                serde_json::from_str(&template_content.unwrap()).unwrap();
472
473            match chat_template_fallback.clone() {
474                Some(t) => {
475                    info!("Loading specified loading chat template file at `{t}`.");
476                    let templ: SpecifiedTemplate =
477                        serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
478                    deser.insert(
479                        "chat_template".to_string(),
480                        Value::String(templ.chat_template),
481                    );
482                    if templ.bos_token.is_some() {
483                        deser.insert(
484                            "bos_token".to_string(),
485                            Value::String(templ.bos_token.unwrap()),
486                        );
487                    }
488                    if templ.eos_token.is_some() {
489                        deser.insert(
490                            "eos_token".to_string(),
491                            Value::String(templ.eos_token.unwrap()),
492                        );
493                    }
494                    if templ.unk_token.is_some() {
495                        deser.insert(
496                            "unk_token".to_string(),
497                            Value::String(templ.unk_token.unwrap()),
498                        );
499                    }
500                }
501                None => {
502                    info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
503                    deser.insert("chat_template".to_string(), Value::Null);
504                }
505            }
506
507            let ser = serde_json::to_string_pretty(&deser)
508                .expect("Serialization of modified chat template failed.");
509            serde_json::from_str(&ser).unwrap()
510        }
511    }
512}
513
// Only compiled for test builds, so the imports below never count as unused in normal builds.
#[cfg(test)]
mod tests {
    use regex_automata::meta::Regex;

    use super::{PICKLE_MATCH, SAFETENSOR_MATCH};

    /// Compiles `pattern` and checks that it matches every name in `positive_ids` and none of
    /// the names in `negative_ids`, naming the offending file on failure.
    fn check_pattern(
        pattern: &str,
        positive_ids: &[&str],
        negative_ids: &[&str],
    ) -> anyhow::Result<()> {
        let regex = Regex::new(pattern)?;
        for id in positive_ids {
            assert!(regex.is_match(*id), "`{id}` should match `{pattern}`");
        }
        for id in negative_ids {
            assert!(!regex.is_match(*id), "`{id}` should not match `{pattern}`");
        }
        Ok(())
    }

    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        check_pattern(
            SAFETENSOR_MATCH,
            &[
                "model-00001-of-00001.safetensors",
                "model-00002-of-00002.safetensors",
                "model-00003-of-00003.safetensors",
                "model-00004-of-00004.safetensors",
                "model-00005-of-00005.safetensors",
                "model-00006-of-00006.safetensors",
            ],
            &[
                "model-0000a-of-00002.safetensors",
                "consolidated.safetensors",
            ],
        )
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        check_pattern(
            PICKLE_MATCH,
            &[
                "pytorch_model-00001-of-00002.bin",
                "pytorch_model-00002-of-00002.bin",
            ],
            &[
                "pytorch_model-000001-of-00001.bin",
                "pytorch_model-0000a-of-00002.bin",
                "pytorch_model-000-of-00003.bin",
                "pytorch_consolidated.bin",
            ],
        )
    }
}