mistralrs_core/pipeline/
paths.rs

1use std::{
2    collections::HashMap,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10    api::sync::{ApiBuilder, ApiRepo},
11    Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{info, warn};
16
17use crate::{
18    api_dir_list, api_get_file,
19    lora::LoraConfig,
20    pipeline::{
21        chat_template::{ChatTemplate, ChatTemplateValue},
22        isq::UQFF_RESIDUAL_SAFETENSORS,
23    },
24    utils::tokens::get_token,
25    xlora_models::XLoraConfig,
26    ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
// Weight files are selected by matching repository filenames against these patterns, which
// avoids picking up unrelated files such as `consolidated.safetensors`.
// The patterns are not anchored, so they match anywhere inside a (possibly nested) path.

/// Sharded safetensors weights, e.g. `model-00001-of-00002.safetensors`.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
/// Single-file safetensors weights, i.e. `model.safetensors`.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
/// Sharded pickle weights, e.g. `pytorch_model-00001-of-00002.bin` (also `.pth` / `.pt`).
/// The dot before the extension is escaped so that it only matches a literal `.`.
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";
33
/// Resolved local files for a single LoRA adapter, as produced by [`get_xlora_paths`].
#[derive(Clone, Debug)]
pub struct LoraAdapterPaths {
    /// Adapter configuration deserialized from the adapter's `adapter_config.json`.
    pub lora_config: mistralrs_quant::LoraConfig,
    /// Local path of the adapter's `adapter_model.safetensors` file.
    pub adapter_path: PathBuf,
}
39
/// The adapter-related files resolved for a model, as returned by [`get_xlora_paths`].
// NOTE(review): the lint is presumably silenced because `XLora` is much larger than the other
// variants and boxing it was not considered worthwhile — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
pub enum AdapterPaths {
    /// Files resolved from an X-LoRA model repository together with its ordering file.
    XLora {
        /// Parsed adapter configs, keyed by `(1-based position in the ordering as a string, adapter name)`.
        adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
        /// `(adapter name, local safetensors path)` pairs.
        adapter_safetensors: Option<Vec<(String, PathBuf)>>,
        /// Local path of the `xlora_classifier.safetensors` file, if the repository has one.
        classifier_path: Option<PathBuf>,
        /// The ordering the paths were resolved with.
        xlora_order: Option<Ordering>,
        /// The parsed `xlora_config.json`, if the repository has one.
        xlora_config: Option<XLoraConfig>,
        /// For each preload adapter name: its local safetensors path and parsed config.
        lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
    },
    /// One entry per plain LoRA adapter.
    Lora(Vec<LoraAdapterPaths>),
    /// No adapters were requested.
    None,
}
54
55pub fn get_xlora_paths(
56    base_model_id: String,
57    xlora_model_id: Option<&String>,
58    lora_adapter_ids: Option<&Vec<String>>,
59    token_source: &TokenSource,
60    revision: String,
61    xlora_order: Option<&Ordering>,
62) -> Result<AdapterPaths> {
63    match (lora_adapter_ids, xlora_model_id, xlora_order) {
64        (None, Some(xlora_id), Some(xlora_order)) => {
65            let api = {
66                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
67                let mut api = ApiBuilder::from_cache(cache)
68                    .with_progress(true)
69                    .with_token(get_token(token_source)?);
70                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
71                    api = api.with_cache_dir(x.into());
72                }
73                api.build().map_err(candle_core::Error::msg)?
74            };
75            let api = api.repo(Repo::with_revision(
76                xlora_id.clone(),
77                RepoType::Model,
78                revision,
79            ));
80            let model_id = Path::new(&xlora_id);
81            let dir_list = api_dir_list!(api, model_id, true).collect::<Vec<_>>();
82            // Get the path for the xlora classifier
83            let xlora_classifier = &dir_list
84                .clone()
85                .into_iter()
86                .filter(|x| x.contains("xlora_classifier.safetensors"))
87                .collect::<Vec<_>>();
88            if xlora_classifier.len() > 1 {
89                warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
90                warn!("Selected classifier: `{}`", &xlora_classifier[0]);
91            }
92            let xlora_classifier = xlora_classifier.first();
93
94            let classifier_path = xlora_classifier
95                .map(|xlora_classifier| api_get_file!(api, xlora_classifier, model_id));
96
97            // Get the path for the xlora config by checking all for valid versions.
98            // NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable
99            let xlora_configs = &dir_list
100                .clone()
101                .into_iter()
102                .filter(|x| x.contains("xlora_config.json"))
103                .collect::<Vec<_>>();
104            if xlora_configs.len() > 1 {
105                warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
106            }
107
108            let mut xlora_config: Option<XLoraConfig> = None;
109            let mut last_err: Option<serde_json::Error> = None;
110            for (i, config_path) in xlora_configs.iter().enumerate() {
111                if xlora_configs.len() != 1 {
112                    warn!("Selecting config: `{}`", config_path);
113                }
114                let config_path = api_get_file!(api, config_path, model_id);
115                let conf = fs::read_to_string(config_path)?;
116                let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
117                match deser {
118                    Ok(conf) => {
119                        xlora_config = Some(conf);
120                        break;
121                    }
122                    Err(e) => {
123                        if i != xlora_configs.len() - 1 {
124                            warn!("Config is broken with error `{e}`");
125                        }
126                        last_err = Some(e);
127                    }
128                }
129            }
130            let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
131                if let Some(last_err) = last_err {
132                    panic!("Unable to derserialize any configs. Last error: {last_err}")
133                } else {
134                    None
135                }
136            });
137
138            // If there are adapters in the ordering file, get their names and remote paths
139            let adapter_files = dir_list
140                .into_iter()
141                .filter_map(|name| {
142                    if let Some(ref adapters) = xlora_order.adapters {
143                        for adapter_name in adapters {
144                            if name.contains(adapter_name) {
145                                return Some((name, adapter_name.clone()));
146                            }
147                        }
148                    }
149                    None
150                })
151                .collect::<Vec<_>>();
152            if adapter_files.is_empty() && xlora_order.adapters.is_some() {
153                anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
154            }
155
156            // Get the local paths for each adapter
157            let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
158            for (file, name) in adapter_files {
159                if let Some(paths) = adapters_paths.get_mut(&name) {
160                    paths.push(api_get_file!(api, &file, model_id));
161                } else {
162                    adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
163                }
164            }
165
166            // Sort local paths for the adapter configs and safetensors files
167            let mut adapters_configs = Vec::new();
168            let mut adapters_safetensors = Vec::new();
169            if let Some(ref adapters) = xlora_order.adapters {
170                for (i, name) in adapters.iter().enumerate() {
171                    let paths = adapters_paths
172                        .get(name)
173                        .unwrap_or_else(|| panic!("Adapter {name} not found."));
174                    for path in paths {
175                        if path.extension().unwrap() == "safetensors" {
176                            adapters_safetensors.push((name.clone(), path.to_owned()));
177                        } else {
178                            let conf = fs::read_to_string(path)?;
179                            let lora_config: LoraConfig = serde_json::from_str(&conf)?;
180                            adapters_configs
181                                .push((((i + 1).to_string(), name.clone()), lora_config));
182                        }
183                    }
184                }
185            }
186
187            // Make sure they all match
188            if xlora_order.base_model_id
189                != *xlora_config
190                    .as_ref()
191                    .map(|cfg| &cfg.base_model_id)
192                    .unwrap_or(&base_model_id)
193                || xlora_config
194                    .as_ref()
195                    .map(|cfg| &cfg.base_model_id)
196                    .unwrap_or(&base_model_id)
197                    != &base_model_id
198            {
199                anyhow::bail!(
200                    "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
201                    xlora_order.base_model_id,
202                    xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
203                    base_model_id
204                );
205            }
206
207            let lora_preload_adapter_info =
208                // If preload adapters are specified, get their metadata like above
209                if let Some(preload_adapters) = &xlora_order.preload_adapters {
210                    let mut output = HashMap::new();
211                    for adapter in preload_adapters {
212                        // Get the names and remote paths of the files associated with this adapter
213                        let adapter_files = api_dir_list!(api, &adapter.adapter_model_id, true)
214                            .filter_map(|f| {
215                                if f.contains(&adapter.name) {
216                                    Some((f, adapter.name.clone()))
217                                } else {
218                                    None
219                                }
220                            })
221                            .collect::<Vec<_>>();
222                        if adapter_files.is_empty() {
223                            anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
224                        }
225                        // Get local paths for this adapter
226                        let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
227                        for (file, name) in adapter_files {
228                            if let Some(paths) = adapters_paths.get_mut(&name) {
229                                paths.push(api_get_file!(api, &file, model_id));
230                            } else {
231                                adapters_paths
232                                    .insert(name, vec![api_get_file!(api, &file, model_id)]);
233                            }
234                        }
235
236                        let mut config = None;
237                        let mut safetensor = None;
238
239                        // Sort local paths for the adapter configs and safetensors files
240                        let paths = adapters_paths
241                            .get(&adapter.name)
242                            .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
243                        for path in paths {
244                            if path.extension().unwrap() == "safetensors" {
245                                safetensor = Some(path.to_owned());
246                            } else {
247                                let conf = fs::read_to_string(path)?;
248                                let lora_config: LoraConfig = serde_json::from_str(&conf)?;
249                                config = Some(lora_config);
250                            }
251                        }
252
253                        let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
254                        output.insert(adapter.name.clone(), (safetensor, config));
255                    }
256                    Some(output)
257                } else {
258                    None
259                };
260
261            Ok(AdapterPaths::XLora {
262                adapter_configs: Some(adapters_configs),
263                adapter_safetensors: Some(adapters_safetensors),
264                classifier_path,
265                xlora_order: Some(xlora_order.clone()),
266                xlora_config,
267                lora_preload_adapter_info,
268            })
269        }
270        (Some(adapter_ids), None, None) => {
271            let mut lora_adapter_paths = Vec::new();
272            for adapter_id in adapter_ids {
273                info!("Loading adapter at `{adapter_id}`");
274
275                let api = {
276                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
277                    let mut api = ApiBuilder::from_cache(cache)
278                        .with_progress(true)
279                        .with_token(get_token(token_source)?);
280                    if let Ok(x) = std::env::var("HF_HUB_CACHE") {
281                        api = api.with_cache_dir(x.into());
282                    }
283                    api.build().map_err(candle_core::Error::msg)?
284                };
285                let api = api.repo(Repo::with_revision(
286                    adapter_id.clone(),
287                    RepoType::Model,
288                    revision.clone(),
289                ));
290
291                let config_path = api.get("adapter_config.json")?;
292                let adapter_path = api.get("adapter_model.safetensors")?;
293                let lora_config: mistralrs_quant::LoraConfig =
294                    serde_json::from_str(&fs::read_to_string(config_path)?)?;
295
296                lora_adapter_paths.push(LoraAdapterPaths {
297                    lora_config,
298                    adapter_path,
299                });
300            }
301
302            Ok(AdapterPaths::Lora(lora_adapter_paths))
303        }
304        (None, None, None) => Ok(AdapterPaths::None),
305        _ => anyhow::bail!(
306            "Incorrect configuration for an adapter model. Lora and XLora are mutually exclusive."
307        ),
308    }
309}
310
311pub fn get_model_paths(
312    revision: String,
313    token_source: &TokenSource,
314    quantized_model_id: Option<&String>,
315    quantized_filename: Option<&Vec<String>>,
316    api: &ApiRepo,
317    model_id: &Path,
318    loading_from_uqff: bool,
319) -> Result<Vec<PathBuf>> {
320    match quantized_filename {
321        Some(names) => {
322            let id = quantized_model_id.unwrap();
323            let mut files = Vec::new();
324
325            for name in names {
326                let qapi = {
327                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
328                    let mut api = ApiBuilder::from_cache(cache)
329                        .with_progress(true)
330                        .with_token(get_token(token_source)?);
331                    if let Ok(x) = std::env::var("HF_HUB_CACHE") {
332                        api = api.with_cache_dir(x.into());
333                    }
334                    api.build().map_err(candle_core::Error::msg)?
335                };
336                let qapi = qapi.repo(Repo::with_revision(
337                    id.to_string(),
338                    RepoType::Model,
339                    revision.clone(),
340                ));
341                let model_id = Path::new(&id);
342                files.push(api_get_file!(qapi, name, model_id));
343            }
344            Ok(files)
345        }
346        None => {
347            // We only match these patterns for model names
348            let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
349            let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
350            let pickle_match = Regex::new(PICKLE_MATCH)?;
351
352            let mut filenames = vec![];
353            let listing = api_dir_list!(api, model_id, true).filter(|x| {
354                safetensor_match.is_match(x)
355                    || pickle_match.is_match(x)
356                    || quant_safetensor_match.is_match(x)
357                    || x == UQFF_RESIDUAL_SAFETENSORS
358            });
359            let safetensors = listing
360                .clone()
361                .filter(|x| x.ends_with(".safetensors"))
362                .collect::<Vec<_>>();
363            let pickles = listing
364                .clone()
365                .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
366                .collect::<Vec<_>>();
367            let uqff_residual = listing
368                .clone()
369                .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
370                .collect::<Vec<_>>();
371            let files = if !safetensors.is_empty() {
372                // Always prefer safetensors
373                safetensors
374            } else if !pickles.is_empty() {
375                // Fall back to pickle
376                pickles
377            } else if !uqff_residual.is_empty() && loading_from_uqff {
378                uqff_residual
379            } else {
380                anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
381            };
382            info!(
383                "Found model weight filenames {:?}",
384                files
385                    .iter()
386                    .map(|x| x.split('/').next_back().unwrap())
387                    .collect::<Vec<_>>()
388            );
389            for rfilename in files {
390                filenames.push(api_get_file!(api, &rfilename, model_id));
391            }
392            Ok(filenames)
393        }
394    }
395}
396
397/// Find and parse the appropriate [`ChatTemplate`], and ensure is has a valid [`ChatTemplate.chat_template`].
398/// If the provided `tokenizer_config.json` from [`ModelPaths.get_template_filename`] does not
399/// have a `chat_template`, use the provided one.
400///
401/// - Uses `chat_template_fallback` if `paths` does not contain a chat template file. This may be a literal or .json file.
402/// - `chat_template_ovrd` (GGUF chat template content) causes the usage of that string chat template initially.
403///   Falls back to `chat_template_file` if it is invalid. *The user must add the bos/unk/eos tokens manually if this
404///   is used.*
405///
406/// THE FOLLOWING IS IGNORED:
407/// After this, if the `chat_template_explicit` filename is specified (a json with one field: "chat_template" OR a jinja file),
408///  the chat template is overwritten with this chat template.
409#[allow(clippy::borrowed_box)]
410pub(crate) fn get_chat_template(
411    paths: &Box<dyn ModelPaths>,
412    jinja_explicit: Option<&String>,
413    chat_template_explicit: Option<&String>,
414    chat_template_fallback: Option<&String>,
415    chat_template_ovrd: Option<String>,
416) -> ChatTemplate {
417    // Get template content, this may be overridden.
418    let template_content = if let Some(template_filename) = paths.get_template_filename() {
419        if !["jinja", "json"].contains(
420            &template_filename
421                .extension()
422                .expect("Template filename must be a file")
423                .to_string_lossy()
424                .to_string()
425                .as_str(),
426        ) {
427            panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
428        }
429        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
430    } else if chat_template_fallback.is_some_and(|f| f.ends_with(".json")) {
431        // User specified a file
432        let template_filename = chat_template_fallback
433            .expect("A tokenizer config or chat template file path must be specified.");
434        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
435    } else if chat_template_ovrd.is_some() {
436        None
437    } else {
438        panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there. If you are running a GGUF model, it probably does not contain a chat template.");
439    };
440    let mut template: ChatTemplate = match chat_template_ovrd {
441        Some(chat_template) => {
442            // In this case the override chat template is being used. The user must add the bos/eos/unk toks themselves.
443            info!("Using literal chat template.");
444            let mut template = ChatTemplate::default();
445            template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
446            template
447        }
448        None => serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap(),
449    };
450    // Overwrite to use any present `chat_template.json`, only if there is not one present already.
451    if template.chat_template.is_none() {
452        if let Some(chat_template_explicit) = chat_template_explicit {
453            let ct =
454                fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
455
456            let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
457                ct
458            } else {
459                #[derive(Debug, serde::Deserialize)]
460                struct AutomaticTemplate {
461                    chat_template: String,
462                }
463                let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
464                deser.chat_template
465            };
466
467            template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
468        }
469    }
470
471    // JINJA explicit
472    if let Some(jinja_explicit) = jinja_explicit {
473        if !jinja_explicit.ends_with(".jinja") {
474            panic!("jinja_explicit must end with .jinja!");
475        }
476
477        let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
478
479        template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
480    }
481
482    let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
483        .get_processor_config()
484        .as_ref()
485        .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
486    if let Some(processor_conf) = processor_conf {
487        if processor_conf.chat_template.is_some() {
488            template.chat_template = processor_conf
489                .chat_template
490                .map(|x| ChatTemplateValue(Either::Left(x)));
491        }
492    }
493
494    #[derive(Debug, serde::Deserialize)]
495    struct SpecifiedTemplate {
496        chat_template: String,
497        bos_token: Option<String>,
498        eos_token: Option<String>,
499        unk_token: Option<String>,
500    }
501
502    if template.chat_template.is_some() {
503        return template;
504    };
505
506    match &template.chat_template {
507        Some(_) => template,
508        None => {
509            info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
510            let mut deser: HashMap<String, Value> =
511                serde_json::from_str(&template_content.unwrap()).unwrap();
512
513            match chat_template_fallback.cloned() {
514                Some(t) => {
515                    info!("Loading specified loading chat template file at `{t}`.");
516                    let templ: SpecifiedTemplate =
517                        serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
518                    deser.insert(
519                        "chat_template".to_string(),
520                        Value::String(templ.chat_template),
521                    );
522                    if templ.bos_token.is_some() {
523                        deser.insert(
524                            "bos_token".to_string(),
525                            Value::String(templ.bos_token.unwrap()),
526                        );
527                    }
528                    if templ.eos_token.is_some() {
529                        deser.insert(
530                            "eos_token".to_string(),
531                            Value::String(templ.eos_token.unwrap()),
532                        );
533                    }
534                    if templ.unk_token.is_some() {
535                        deser.insert(
536                            "unk_token".to_string(),
537                            Value::String(templ.unk_token.unwrap()),
538                        );
539                    }
540                }
541                None => {
542                    info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
543                    deser.insert("chat_template".to_string(), Value::Null);
544                }
545            }
546
547            let ser = serde_json::to_string_pretty(&deser)
548                .expect("Serialization of modified chat template failed.");
549            serde_json::from_str(&ser).unwrap()
550        }
551    }
552}
553
// Unit tests for the weight filename patterns; only compiled for test builds.
#[cfg(test)]
mod tests {
    use regex_automata::meta::Regex;

    use super::{PICKLE_MATCH, QUANT_SAFETENSOR_MATCH, SAFETENSOR_MATCH};

    /// Asserts that `pattern` matches every name in `positive_ids` and none in `negative_ids`.
    fn check_pattern(
        pattern: &str,
        positive_ids: &[&str],
        negative_ids: &[&str],
    ) -> anyhow::Result<()> {
        let matcher = Regex::new(pattern)?;
        for id in positive_ids {
            assert!(matcher.is_match(id), "`{id}` should match `{pattern}`");
        }
        for id in negative_ids {
            assert!(!matcher.is_match(id), "`{id}` should not match `{pattern}`");
        }
        Ok(())
    }

    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        check_pattern(
            SAFETENSOR_MATCH,
            &[
                "model-00001-of-00001.safetensors",
                "model-00002-of-00002.safetensors",
                "model-00003-of-00003.safetensors",
                "model-00004-of-00004.safetensors",
                "model-00005-of-00005.safetensors",
                "model-00006-of-00006.safetensors",
            ],
            &[
                "model-0000a-of-00002.safetensors",
                "consolidated.safetensors",
            ],
        )
    }

    #[test]
    fn match_quant_safetensors() -> anyhow::Result<()> {
        check_pattern(
            QUANT_SAFETENSOR_MATCH,
            &["model.safetensors"],
            &[
                "consolidated.safetensors",
                "model-00001-of-00002.safetensors",
            ],
        )
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        check_pattern(
            PICKLE_MATCH,
            &[
                "pytorch_model-00001-of-00002.bin",
                "pytorch_model-00002-of-00002.bin",
            ],
            &[
                "pytorch_model-000001-of-00001.bin",
                "pytorch_model-0000a-of-00002.bin",
                "pytorch_model-000-of-00003.bin",
                "pytorch_consolidated.bin",
            ],
        )
    }
}