// mistralrs_core/pipeline/paths.rs

1use std::{
2    collections::HashMap,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10    api::sync::{ApiBuilder, ApiRepo},
11    Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{info, warn};
16
17use crate::{
18    api_dir_list, api_get_file,
19    lora::LoraConfig,
20    pipeline::{
21        chat_template::{ChatTemplate, ChatTemplateValue},
22        isq::UQFF_RESIDUAL_SAFETENSORS,
23    },
24    utils::tokens::get_token,
25    xlora_models::XLoraConfig,
26    ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
// Match files against these, avoids situations like `consolidated.safetensors`

/// Sharded safetensors weights, e.g. `model-00001-of-00002.safetensors`.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
/// Single-file safetensors weights, i.e. `model.safetensors`.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
/// Sharded pickle weights with five-digit shard numbers, e.g. `pytorch_model-00001-of-00002.bin`.
/// The dot before the extension is escaped so it only matches a literal `.`.
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";
33
34#[derive(Clone, Debug)]
35pub struct LoraAdapterPaths {
36    pub lora_config: mistralrs_quant::LoraConfig,
37    pub adapter_path: PathBuf,
38}
39
40#[allow(clippy::large_enum_variant)]
41#[derive(Clone, Debug)]
42pub enum AdapterPaths {
43    XLora {
44        adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
45        adapter_safetensors: Option<Vec<(String, PathBuf)>>,
46        classifier_path: Option<PathBuf>,
47        xlora_order: Option<Ordering>,
48        xlora_config: Option<XLoraConfig>,
49        lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
50    },
51    Lora(Vec<LoraAdapterPaths>),
52    None,
53}
54
55pub fn get_xlora_paths(
56    base_model_id: String,
57    xlora_model_id: Option<&String>,
58    lora_adapter_ids: Option<&Vec<String>>,
59    token_source: &TokenSource,
60    revision: String,
61    xlora_order: Option<&Ordering>,
62) -> Result<AdapterPaths> {
63    match (lora_adapter_ids, xlora_model_id, xlora_order) {
64        (None, Some(xlora_id), Some(xlora_order)) => {
65            let api = {
66                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
67                let mut api = ApiBuilder::from_cache(cache)
68                    .with_progress(true)
69                    .with_token(get_token(token_source)?);
70                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
71                    api = api.with_cache_dir(x.into());
72                }
73                api.build().map_err(candle_core::Error::msg)?
74            };
75            let api = api.repo(Repo::with_revision(
76                xlora_id.clone(),
77                RepoType::Model,
78                revision,
79            ));
80            let model_id = Path::new(&xlora_id);
81            let dir_list = api_dir_list!(api, model_id, true).collect::<Vec<_>>();
82            // Get the path for the xlora classifier
83            let xlora_classifier = &dir_list
84                .clone()
85                .into_iter()
86                .filter(|x| x.contains("xlora_classifier.safetensors"))
87                .collect::<Vec<_>>();
88            if xlora_classifier.len() > 1 {
89                warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
90                warn!("Selected classifier: `{}`", &xlora_classifier[0]);
91            }
92            let xlora_classifier = xlora_classifier.first();
93
94            let classifier_path = xlora_classifier
95                .map(|xlora_classifier| api_get_file!(api, xlora_classifier, model_id));
96
97            // Get the path for the xlora config by checking all for valid versions.
98            // NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable
99            let xlora_configs = &dir_list
100                .clone()
101                .into_iter()
102                .filter(|x| x.contains("xlora_config.json"))
103                .collect::<Vec<_>>();
104            if xlora_configs.len() > 1 {
105                warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
106            }
107
108            let mut xlora_config: Option<XLoraConfig> = None;
109            let mut last_err: Option<serde_json::Error> = None;
110            for (i, config_path) in xlora_configs.iter().enumerate() {
111                if xlora_configs.len() != 1 {
112                    warn!("Selecting config: `{}`", config_path);
113                }
114                let config_path = api_get_file!(api, config_path, model_id);
115                let conf = fs::read_to_string(config_path)?;
116                let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
117                match deser {
118                    Ok(conf) => {
119                        xlora_config = Some(conf);
120                        break;
121                    }
122                    Err(e) => {
123                        if i != xlora_configs.len() - 1 {
124                            warn!("Config is broken with error `{e}`");
125                        }
126                        last_err = Some(e);
127                    }
128                }
129            }
130            let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
131                if let Some(last_err) = last_err {
132                    panic!("Unable to derserialize any configs. Last error: {last_err}")
133                } else {
134                    None
135                }
136            });
137
138            // If there are adapters in the ordering file, get their names and remote paths
139            let adapter_files = dir_list
140                .into_iter()
141                .filter_map(|name| {
142                    if let Some(ref adapters) = xlora_order.adapters {
143                        for adapter_name in adapters {
144                            if name.contains(adapter_name) {
145                                return Some((name, adapter_name.clone()));
146                            }
147                        }
148                    }
149                    None
150                })
151                .collect::<Vec<_>>();
152            if adapter_files.is_empty() && xlora_order.adapters.is_some() {
153                anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
154            }
155
156            // Get the local paths for each adapter
157            let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
158            for (file, name) in adapter_files {
159                if let Some(paths) = adapters_paths.get_mut(&name) {
160                    paths.push(api_get_file!(api, &file, model_id));
161                } else {
162                    adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
163                }
164            }
165
166            // Sort local paths for the adapter configs and safetensors files
167            let mut adapters_configs = Vec::new();
168            let mut adapters_safetensors = Vec::new();
169            if let Some(ref adapters) = xlora_order.adapters {
170                for (i, name) in adapters.iter().enumerate() {
171                    let paths = adapters_paths
172                        .get(name)
173                        .unwrap_or_else(|| panic!("Adapter {name} not found."));
174                    for path in paths {
175                        if path.extension().unwrap() == "safetensors" {
176                            adapters_safetensors.push((name.clone(), path.to_owned()));
177                        } else {
178                            let conf = fs::read_to_string(path)?;
179                            let lora_config: LoraConfig = serde_json::from_str(&conf)?;
180                            adapters_configs
181                                .push((((i + 1).to_string(), name.clone()), lora_config));
182                        }
183                    }
184                }
185            }
186
187            // Make sure they all match
188            if xlora_order.base_model_id
189                != *xlora_config
190                    .as_ref()
191                    .map(|cfg| &cfg.base_model_id)
192                    .unwrap_or(&base_model_id)
193                || xlora_config
194                    .as_ref()
195                    .map(|cfg| &cfg.base_model_id)
196                    .unwrap_or(&base_model_id)
197                    != &base_model_id
198            {
199                anyhow::bail!(
200                    "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
201                    xlora_order.base_model_id,
202                    xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
203                    base_model_id
204                );
205            }
206
207            let lora_preload_adapter_info =
208                // If preload adapters are specified, get their metadata like above
209                if let Some(preload_adapters) = &xlora_order.preload_adapters {
210                    let mut output = HashMap::new();
211                    for adapter in preload_adapters {
212                        // Get the names and remote paths of the files associated with this adapter
213                        let adapter_files = api_dir_list!(api, &adapter.adapter_model_id, true)
214                            .filter_map(|f| {
215                                if f.contains(&adapter.name) {
216                                    Some((f, adapter.name.clone()))
217                                } else {
218                                    None
219                                }
220                            })
221                            .collect::<Vec<_>>();
222                        if adapter_files.is_empty() {
223                            anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
224                        }
225                        // Get local paths for this adapter
226                        let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
227                        for (file, name) in adapter_files {
228                            if let Some(paths) = adapters_paths.get_mut(&name) {
229                                paths.push(api_get_file!(api, &file, model_id));
230                            } else {
231                                adapters_paths
232                                    .insert(name, vec![api_get_file!(api, &file, model_id)]);
233                            }
234                        }
235
236                        let mut config = None;
237                        let mut safetensor = None;
238
239                        // Sort local paths for the adapter configs and safetensors files
240                        let paths = adapters_paths
241                            .get(&adapter.name)
242                            .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
243                        for path in paths {
244                            if path.extension().unwrap() == "safetensors" {
245                                safetensor = Some(path.to_owned());
246                            } else {
247                                let conf = fs::read_to_string(path)?;
248                                let lora_config: LoraConfig = serde_json::from_str(&conf)?;
249                                config = Some(lora_config);
250                            }
251                        }
252
253                        let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
254                        output.insert(adapter.name.clone(), (safetensor, config));
255                    }
256                    Some(output)
257                } else {
258                    None
259                };
260
261            Ok(AdapterPaths::XLora {
262                adapter_configs: Some(adapters_configs),
263                adapter_safetensors: Some(adapters_safetensors),
264                classifier_path,
265                xlora_order: Some(xlora_order.clone()),
266                xlora_config,
267                lora_preload_adapter_info,
268            })
269        }
270        (Some(adapter_ids), None, None) => {
271            let mut lora_adapter_paths = Vec::new();
272            for adapter_id in adapter_ids {
273                info!("Loading adapter at `{adapter_id}`");
274
275                let api = {
276                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
277                    let mut api = ApiBuilder::from_cache(cache)
278                        .with_progress(true)
279                        .with_token(get_token(token_source)?);
280                    if let Ok(x) = std::env::var("HF_HUB_CACHE") {
281                        api = api.with_cache_dir(x.into());
282                    }
283                    api.build().map_err(candle_core::Error::msg)?
284                };
285                let api = api.repo(Repo::with_revision(
286                    adapter_id.clone(),
287                    RepoType::Model,
288                    revision.clone(),
289                ));
290
291                let config_path = api.get("adapter_config.json")?;
292                let adapter_path = api.get("adapter_model.safetensors")?;
293                let lora_config: mistralrs_quant::LoraConfig =
294                    serde_json::from_str(&fs::read_to_string(config_path)?)?;
295
296                lora_adapter_paths.push(LoraAdapterPaths {
297                    lora_config,
298                    adapter_path,
299                });
300            }
301
302            Ok(AdapterPaths::Lora(lora_adapter_paths))
303        }
304        (None, None, None) => Ok(AdapterPaths::None),
305        _ => anyhow::bail!(
306            "Incorrect configuration for an adapter model. Lora and XLora are mutually exclusive."
307        ),
308    }
309}
310
311pub fn get_model_paths(
312    revision: String,
313    token_source: &TokenSource,
314    quantized_model_id: Option<&String>,
315    quantized_filename: Option<&Vec<String>>,
316    api: &ApiRepo,
317    model_id: &Path,
318    loading_from_uqff: bool,
319) -> Result<Vec<PathBuf>> {
320    match quantized_filename {
321        Some(names) => {
322            let id = quantized_model_id.unwrap();
323            let mut files = Vec::new();
324
325            for name in names {
326                let qapi = {
327                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
328                    let mut api = ApiBuilder::from_cache(cache)
329                        .with_progress(true)
330                        .with_token(get_token(token_source)?);
331                    if let Ok(x) = std::env::var("HF_HUB_CACHE") {
332                        api = api.with_cache_dir(x.into());
333                    }
334                    api.build().map_err(candle_core::Error::msg)?
335                };
336                let qapi = qapi.repo(Repo::with_revision(
337                    id.to_string(),
338                    RepoType::Model,
339                    revision.clone(),
340                ));
341                let model_id = Path::new(&id);
342                files.push(api_get_file!(qapi, name, model_id));
343            }
344            Ok(files)
345        }
346        None => {
347            // We only match these patterns for model names
348            let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
349            let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
350            let pickle_match = Regex::new(PICKLE_MATCH)?;
351
352            let mut filenames = vec![];
353            let listing = api_dir_list!(api, model_id, true).filter(|x| {
354                safetensor_match.is_match(x)
355                    || pickle_match.is_match(x)
356                    || quant_safetensor_match.is_match(x)
357                    || x == UQFF_RESIDUAL_SAFETENSORS
358            });
359            let safetensors = listing
360                .clone()
361                .filter(|x| x.ends_with(".safetensors"))
362                .collect::<Vec<_>>();
363            let pickles = listing
364                .clone()
365                .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
366                .collect::<Vec<_>>();
367            let uqff_residual = listing
368                .clone()
369                .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
370                .collect::<Vec<_>>();
371            let files = if !safetensors.is_empty() {
372                // Always prefer safetensors
373                safetensors
374            } else if !pickles.is_empty() {
375                // Fall back to pickle
376                pickles
377            } else if !uqff_residual.is_empty() && loading_from_uqff {
378                uqff_residual
379            } else {
380                anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
381            };
382            info!(
383                "Found model weight filenames {:?}",
384                files
385                    .iter()
386                    .map(|x| x.split('/').next_back().unwrap())
387                    .collect::<Vec<_>>()
388            );
389            for rfilename in files {
390                filenames.push(api_get_file!(api, &rfilename, model_id));
391            }
392            Ok(filenames)
393        }
394    }
395}
396
397/// Find and parse the appropriate [`ChatTemplate`], and ensure is has a valid [`ChatTemplate.chat_template`].
398/// If the provided `tokenizer_config.json` from [`ModelPaths.get_template_filename`] does not
399/// have a `chat_template`, use the provided one.
400///
401/// - Uses `chat_template_fallback` if `paths` does not contain a chat template file. This may be a literal or .json file.
402/// - `chat_template_ovrd` (GGUF chat template content) causes the usage of that string chat template initially.
403///   Falls back to `chat_template_file` if it is invalid. *The user must add the bos/unk/eos tokens manually if this
404///   is used.*
405///
406/// THE FOLLOWING IS IGNORED:
407/// After this, if the `chat_template_explicit` filename is specified (a json with one field: "chat_template" OR a jinja file),
408///  the chat template is overwritten with this chat template.
409#[allow(clippy::borrowed_box)]
410pub(crate) fn get_chat_template(
411    paths: &Box<dyn ModelPaths>,
412    jinja_explicit: Option<&String>,
413    chat_template_explicit: Option<&String>,
414    chat_template_fallback: Option<&String>,
415    chat_template_ovrd: Option<String>,
416) -> ChatTemplate {
417    // Get template content, this may be overridden.
418    let template_content = if let Some(template_filename) = paths.get_template_filename() {
419        if !["jinja", "json"].contains(
420            &template_filename
421                .extension()
422                .expect("Template filename must be a file")
423                .to_string_lossy()
424                .to_string()
425                .as_str(),
426        ) {
427            panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
428        }
429        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
430    } else if chat_template_fallback.is_some_and(|f| f.ends_with(".json")) {
431        // User specified a file
432        let template_filename = chat_template_fallback
433            .expect("A tokenizer config or chat template file path must be specified.");
434        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
435    } else if chat_template_ovrd.is_some() {
436        None
437    } else {
438        panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there. If you are running a GGUF model, it probably does not contain a chat template.");
439    };
440    let mut template: ChatTemplate = match chat_template_ovrd {
441        Some(chat_template) => {
442            // In this case the override chat template is being used. The user must add the bos/eos/unk toks themselves.
443            info!("Using literal chat template.");
444            let mut template = ChatTemplate::default();
445            template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
446            template
447        }
448        None => {
449            // Check if template_filename is a .jinja file
450            if let Some(template_filename) = paths.get_template_filename() {
451                if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
452                    info!("Using chat template from .jinja file.");
453                    let mut template = ChatTemplate::default();
454                    template.chat_template = Some(ChatTemplateValue(Either::Left(
455                        template_content.as_ref().unwrap().clone(),
456                    )));
457                    template
458                } else {
459                    serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap()
460                }
461            } else {
462                serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap()
463            }
464        }
465    };
466    // Overwrite to use any present `chat_template.json`, only if there is not one present already.
467    if template.chat_template.is_none() {
468        if let Some(chat_template_explicit) = chat_template_explicit {
469            let ct =
470                fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
471
472            let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
473                ct
474            } else {
475                #[derive(Debug, serde::Deserialize)]
476                struct AutomaticTemplate {
477                    chat_template: String,
478                }
479                let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
480                deser.chat_template
481            };
482
483            template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
484        }
485    }
486
487    // JINJA explicit
488    if let Some(jinja_explicit) = jinja_explicit {
489        if !jinja_explicit.ends_with(".jinja") {
490            panic!("jinja_explicit must end with .jinja!");
491        }
492
493        let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
494
495        template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
496    }
497
498    let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
499        .get_processor_config()
500        .as_ref()
501        .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
502    if let Some(processor_conf) = processor_conf {
503        if processor_conf.chat_template.is_some() {
504            template.chat_template = processor_conf
505                .chat_template
506                .map(|x| ChatTemplateValue(Either::Left(x)));
507        }
508    }
509
510    #[derive(Debug, serde::Deserialize)]
511    struct SpecifiedTemplate {
512        chat_template: String,
513        bos_token: Option<String>,
514        eos_token: Option<String>,
515        unk_token: Option<String>,
516    }
517
518    if template.chat_template.is_some() {
519        return template;
520    };
521
522    match &template.chat_template {
523        Some(_) => template,
524        None => {
525            info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
526            let mut deser: HashMap<String, Value> =
527                serde_json::from_str(&template_content.unwrap()).unwrap();
528
529            match chat_template_fallback.cloned() {
530                Some(t) => {
531                    info!("Loading specified loading chat template file at `{t}`.");
532                    let templ: SpecifiedTemplate =
533                        serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
534                    deser.insert(
535                        "chat_template".to_string(),
536                        Value::String(templ.chat_template),
537                    );
538                    if templ.bos_token.is_some() {
539                        deser.insert(
540                            "bos_token".to_string(),
541                            Value::String(templ.bos_token.unwrap()),
542                        );
543                    }
544                    if templ.eos_token.is_some() {
545                        deser.insert(
546                            "eos_token".to_string(),
547                            Value::String(templ.eos_token.unwrap()),
548                        );
549                    }
550                    if templ.unk_token.is_some() {
551                        deser.insert(
552                            "unk_token".to_string(),
553                            Value::String(templ.unk_token.unwrap()),
554                        );
555                    }
556                }
557                None => {
558                    warn!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
559                    deser.insert("chat_template".to_string(), Value::Null);
560                }
561            }
562
563            let ser = serde_json::to_string_pretty(&deser)
564                .expect("Serialization of modified chat template failed.");
565            serde_json::from_str(&ser).unwrap()
566        }
567    }
568}
569
#[cfg(test)]
mod tests {
    use regex_automata::meta::Regex;

    use super::{PICKLE_MATCH, QUANT_SAFETENSOR_MATCH, SAFETENSOR_MATCH};

    /// Asserts that `pattern` matches every entry of `positive` and none of `negative`.
    fn check(pattern: &str, positive: &[&str], negative: &[&str]) -> anyhow::Result<()> {
        let re = Regex::new(pattern)?;
        for id in positive {
            assert!(re.is_match(*id), "`{pattern}` should match `{id}`");
        }
        for id in negative {
            assert!(!re.is_match(*id), "`{pattern}` should not match `{id}`");
        }
        Ok(())
    }

    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        check(
            SAFETENSOR_MATCH,
            &[
                "model-00001-of-00001.safetensors",
                "model-00002-of-00002.safetensors",
                "model-00003-of-00003.safetensors",
                "model-00004-of-00004.safetensors",
                "model-00005-of-00005.safetensors",
                "model-00006-of-00006.safetensors",
            ],
            &["model-0000a-of-00002.safetensors", "consolidated.safetensors"],
        )
    }

    #[test]
    fn match_quant_safetensors() -> anyhow::Result<()> {
        check(
            QUANT_SAFETENSOR_MATCH,
            &["model.safetensors"],
            &["consolidated.safetensors", "model-00001-of-00002.safetensors"],
        )
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        check(
            PICKLE_MATCH,
            &[
                "pytorch_model-00001-of-00002.bin",
                "pytorch_model-00002-of-00002.bin",
            ],
            &[
                "pytorch_model-000001-of-00001.bin",
                "pytorch_model-0000a-of-00002.bin",
                "pytorch_model-000-of-00003.bin",
                "pytorch_consolidated.bin",
            ],
        )
    }
}