// mistralrs_core/pipeline/auto.rs

1use super::{
2    EmbeddingLoaderBuilder, EmbeddingLoaderType, EmbeddingSpecificConfig, Loader, ModelKind,
3    ModelPaths, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, TokenSource,
4    VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
5};
6use crate::api_get_file;
7use crate::utils::{progress::ProgressScopeGuard, tokens::get_token};
8use crate::Ordering;
9use crate::{DeviceMapSetting, IsqType, PagedAttentionConfig, Pipeline, TryIntoDType};
10use anyhow::Result;
11use candle_core::Device;
12use hf_hub::{
13    api::sync::{ApiBuilder, ApiRepo},
14    Cache, Repo, RepoType,
15};
16use serde::Deserialize;
17use std::path::Path;
18use std::path::PathBuf;
19use std::sync::Arc;
20use std::sync::Mutex;
21use tracing::{debug, info};
22
/// Automatically selects between a normal or vision loader based on the `architectures` field.
pub struct AutoLoader {
    // Model identifier: a Hugging Face hub repo id or a local path.
    model_id: String,
    // One candidate builder per loader family. Exactly one is consumed
    // (`take`n) once the model kind has been detected; `Mutex<Option<..>>`
    // gives the one-shot interior mutability needed behind `&self`.
    normal_builder: Mutex<Option<NormalLoaderBuilder>>,
    vision_builder: Mutex<Option<VisionLoaderBuilder>>,
    embedding_builder: Mutex<Option<EmbeddingLoaderBuilder>>,
    // The concrete loader, populated lazily by `ensure_loader`.
    loader: Mutex<Option<Box<dyn Loader>>>,
    // Optional override for the Hugging Face cache directory.
    hf_cache_path: Option<PathBuf>,
}
32
/// Builder for [`AutoLoader`]. Collects configuration for every loader
/// family up front; the family actually used is chosen later from the
/// model's `config.json`.
pub struct AutoLoaderBuilder {
    // Per-family specific configs; only the selected family's is used.
    normal_cfg: NormalSpecificConfig,
    vision_cfg: VisionSpecificConfig,
    embedding_cfg: EmbeddingSpecificConfig,
    // Optional chat template / tokenizer overrides, forwarded to the
    // underlying builders.
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    model_id: String,
    // Explicit Jinja template override, forwarded to the underlying builders.
    jinja_explicit: Option<String>,
    no_kv_cache: bool,
    // X-LoRA settings (set via `with_xlora`).
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    tgt_non_granular_index: Option<usize>,
    // LoRA adapter ids (set via `with_lora`).
    lora_adapter_ids: Option<Vec<String>>,
    // Optional override for the Hugging Face cache directory.
    hf_cache_path: Option<PathBuf>,
}
48
49impl AutoLoaderBuilder {
50    #[allow(clippy::too_many_arguments)]
51    pub fn new(
52        normal_cfg: NormalSpecificConfig,
53        vision_cfg: VisionSpecificConfig,
54        embedding_cfg: EmbeddingSpecificConfig,
55        chat_template: Option<String>,
56        tokenizer_json: Option<String>,
57        model_id: String,
58        no_kv_cache: bool,
59        jinja_explicit: Option<String>,
60    ) -> Self {
61        Self {
62            normal_cfg,
63            vision_cfg,
64            embedding_cfg,
65            chat_template,
66            tokenizer_json,
67            model_id,
68            jinja_explicit,
69            no_kv_cache,
70            xlora_model_id: None,
71            xlora_order: None,
72            tgt_non_granular_index: None,
73            lora_adapter_ids: None,
74            hf_cache_path: None,
75        }
76    }
77
78    pub fn with_xlora(
79        mut self,
80        model_id: String,
81        order: Ordering,
82        no_kv_cache: bool,
83        tgt_non_granular_index: Option<usize>,
84    ) -> Self {
85        self.xlora_model_id = Some(model_id);
86        self.xlora_order = Some(order);
87        self.no_kv_cache = no_kv_cache;
88        self.tgt_non_granular_index = tgt_non_granular_index;
89        self
90    }
91
92    pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
93        self.lora_adapter_ids = Some(adapters);
94        self
95    }
96
97    pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
98        self.hf_cache_path = Some(path);
99        self
100    }
101
102    pub fn build(self) -> Box<dyn Loader> {
103        let Self {
104            normal_cfg,
105            vision_cfg,
106            embedding_cfg,
107            chat_template,
108            tokenizer_json,
109            model_id,
110            jinja_explicit,
111            no_kv_cache,
112            xlora_model_id,
113            xlora_order,
114            tgt_non_granular_index,
115            lora_adapter_ids,
116            hf_cache_path,
117        } = self;
118
119        let mut normal_builder = NormalLoaderBuilder::new(
120            normal_cfg,
121            chat_template.clone(),
122            tokenizer_json.clone(),
123            Some(model_id.clone()),
124            no_kv_cache,
125            jinja_explicit.clone(),
126        );
127        if let (Some(id), Some(ord)) = (xlora_model_id.clone(), xlora_order.clone()) {
128            normal_builder =
129                normal_builder.with_xlora(id, ord, no_kv_cache, tgt_non_granular_index);
130        }
131        if let Some(ref adapters) = lora_adapter_ids {
132            normal_builder = normal_builder.with_lora(adapters.clone());
133        }
134        if let Some(ref path) = hf_cache_path {
135            normal_builder = normal_builder.hf_cache_path(path.clone());
136        }
137
138        let mut vision_builder = VisionLoaderBuilder::new(
139            vision_cfg,
140            chat_template,
141            tokenizer_json.clone(),
142            Some(model_id.clone()),
143            jinja_explicit,
144        );
145        if let Some(ref adapters) = lora_adapter_ids {
146            vision_builder = vision_builder.with_lora(adapters.clone());
147        }
148        if let Some(ref path) = hf_cache_path {
149            vision_builder = vision_builder.hf_cache_path(path.clone());
150        }
151
152        let mut embedding_builder =
153            EmbeddingLoaderBuilder::new(embedding_cfg, tokenizer_json, Some(model_id.clone()));
154        if let Some(ref adapters) = lora_adapter_ids {
155            embedding_builder = embedding_builder.with_lora(adapters.clone());
156        }
157        if let Some(ref path) = hf_cache_path {
158            embedding_builder = embedding_builder.hf_cache_path(path.clone());
159        }
160
161        Box::new(AutoLoader {
162            model_id,
163            normal_builder: Mutex::new(Some(normal_builder)),
164            vision_builder: Mutex::new(Some(vision_builder)),
165            embedding_builder: Mutex::new(Some(embedding_builder)),
166            loader: Mutex::new(None),
167            hf_cache_path,
168        })
169    }
170}
171
/// Minimal view of `config.json`: only the `architectures` list is needed
/// to route to the right loader.
#[derive(Deserialize)]
struct AutoConfig {
    // Defaults to an empty vec when the field is absent from the config.
    #[serde(default)]
    architectures: Vec<String>,
}
177
/// Raw `config.json` text plus whether a `config_sentence_transformers.json`
/// companion file was found (which signals an embedding model).
struct ConfigArtifacts {
    contents: String,
    sentence_transformers_present: bool,
}
182
/// Outcome of architecture detection in [`AutoLoader::detect`].
enum Detected {
    Normal(NormalLoaderType),
    Vision(VisionLoaderType),
    // `None` means the architecture name wasn't recognized as a known
    // embedding type; routing falls through to the embedding loader's own
    // auto-detection.
    Embedding(Option<EmbeddingLoaderType>),
}
188
impl AutoLoader {
    /// Read `config.json` from already-resolved local model paths, noting
    /// whether a `config_sentence_transformers.json` sibling exists next to
    /// it (marks the model as an embedding model).
    fn read_config_from_path(&self, paths: &dyn ModelPaths) -> Result<ConfigArtifacts> {
        let config_path = paths.get_config_filename();
        let contents = std::fs::read_to_string(config_path)?;
        let sentence_transformers_present = Self::has_sentence_transformers_sibling(config_path);
        Ok(ConfigArtifacts {
            contents,
            sentence_transformers_present,
        })
    }

    /// Fetch `config.json` for `self.model_id` via the Hugging Face hub API
    /// (served from the local HF cache when available) and probe for a
    /// `config_sentence_transformers.json` companion file.
    fn read_config_from_hf(
        &self,
        revision: Option<String>,
        token_source: &TokenSource,
        silent: bool,
    ) -> Result<ConfigArtifacts> {
        // Prefer the explicitly configured cache path; otherwise use the
        // default HF cache location.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        let mut api = ApiBuilder::from_cache(cache)
            .with_progress(!silent)
            .with_token(get_token(token_source)?);
        // `HF_HUB_CACHE` overrides whichever cache dir was chosen above.
        if let Ok(x) = std::env::var("HF_HUB_CACHE") {
            api = api.with_cache_dir(x.into());
        }
        let api = api.build()?;
        // Default to the `main` revision when none was requested.
        let revision = revision.unwrap_or_else(|| "main".to_string());
        let api = api.repo(Repo::with_revision(
            self.model_id.clone(),
            RepoType::Model,
            revision,
        ));
        let model_id = Path::new(&self.model_id);
        let config_filename = api_get_file!(api, "config.json", model_id);
        let contents = std::fs::read_to_string(&config_filename)?;
        // The companion file may already sit next to the cached config, or
        // may still need to be fetched from the hub (short-circuits on the
        // local check).
        let sentence_transformers_present =
            Self::has_sentence_transformers_sibling(&config_filename)
                || Self::fetch_sentence_transformers_config(&api, model_id);
        Ok(ConfigArtifacts {
            contents,
            sentence_transformers_present,
        })
    }

    /// True if `config_sentence_transformers.json` exists in the same
    /// directory as `config_path`.
    fn has_sentence_transformers_sibling(config_path: &Path) -> bool {
        config_path
            .parent()
            .map(|parent| parent.join("config_sentence_transformers.json").exists())
            .unwrap_or(false)
    }

    /// Attempt to download `config_sentence_transformers.json` from the hub,
    /// returning whether it could be retrieved. Local model paths are
    /// skipped (the sibling check already covered them); a failed fetch is
    /// treated as "not present" and only logged at debug level.
    fn fetch_sentence_transformers_config(api: &ApiRepo, model_id: &Path) -> bool {
        if model_id.exists() {
            return false;
        }
        match api.get("config_sentence_transformers.json") {
            Ok(_) => true,
            Err(err) => {
                debug!(
                    "No `config_sentence_transformers.json` found for `{}`: {err}",
                    model_id.display()
                );
                false
            }
        }
    }

    /// Decide which loader family to use from the config's `architectures`
    /// field. When `allow_embedding` is set (a sentence-transformers config
    /// was found) the embedding path is always taken; otherwise exactly one
    /// architecture is required, with vision types preferred over text-only.
    fn detect(&self, config: &str, allow_embedding: bool) -> Result<Detected> {
        let cfg: AutoConfig = serde_json::from_str(config)?;
        if allow_embedding {
            if let Some(name) = cfg.architectures.first() {
                if let Ok(tp) = EmbeddingLoaderType::from_causal_lm_name(name) {
                    info!(
                        "Detected `config_sentence_transformers.json`; using embedding loader `{tp}`."
                    );
                    return Ok(Detected::Embedding(Some(tp)));
                }
            }
            // Architecture not recognized as a known embedding type: defer
            // to the embedding loader's own auto-detection.
            info!(
                "Detected `config_sentence_transformers.json`; routing via auto embedding loader."
            );
            return Ok(Detected::Embedding(None));
        }
        if cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }
        let name = &cfg.architectures[0];
        if let Ok(tp) = VisionLoaderType::from_causal_lm_name(name) {
            return Ok(Detected::Vision(tp));
        }
        let tp = NormalLoaderType::from_causal_lm_name(name)?;
        Ok(Detected::Normal(tp))
    }

    /// Lazily construct the concrete inner loader; idempotent. The matching
    /// builder is `take`n out of its slot, so it can be consumed at most
    /// once (the `expect` only fires if the slot was somehow emptied without
    /// `self.loader` being set — an internal invariant violation).
    fn ensure_loader(&self, config: &str, allow_embedding: bool) -> Result<()> {
        let mut guard = self.loader.lock().unwrap();
        if guard.is_some() {
            // Already built on a previous call.
            return Ok(());
        }
        match self.detect(config, allow_embedding)? {
            Detected::Normal(tp) => {
                let builder = self
                    .normal_builder
                    .lock()
                    .unwrap()
                    .take()
                    .expect("builder taken");
                let loader = builder.build(Some(tp)).expect("build normal");
                *guard = Some(loader);
            }
            Detected::Vision(tp) => {
                let builder = self
                    .vision_builder
                    .lock()
                    .unwrap()
                    .take()
                    .expect("builder taken");
                let loader = builder.build(Some(tp));
                *guard = Some(loader);
            }
            Detected::Embedding(tp) => {
                let builder = self
                    .embedding_builder
                    .lock()
                    .unwrap()
                    .take()
                    .expect("builder taken");
                let loader = builder.build(tp);
                *guard = Some(loader);
            }
        }
        Ok(())
    }
}
326
327impl Loader for AutoLoader {
328    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
329    fn load_model_from_hf(
330        &self,
331        revision: Option<String>,
332        token_source: TokenSource,
333        dtype: &dyn TryIntoDType,
334        device: &Device,
335        silent: bool,
336        mapper: DeviceMapSetting,
337        in_situ_quant: Option<IsqType>,
338        paged_attn_config: Option<PagedAttentionConfig>,
339    ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
340        let _progress_guard = ProgressScopeGuard::new(silent);
341        let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
342        self.ensure_loader(&config.contents, config.sentence_transformers_present)?;
343        self.loader
344            .lock()
345            .unwrap()
346            .as_ref()
347            .unwrap()
348            .load_model_from_hf(
349                revision,
350                token_source,
351                dtype,
352                device,
353                silent,
354                mapper,
355                in_situ_quant,
356                paged_attn_config,
357            )
358    }
359
360    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
361    fn load_model_from_path(
362        &self,
363        paths: &Box<dyn ModelPaths>,
364        dtype: &dyn TryIntoDType,
365        device: &Device,
366        silent: bool,
367        mapper: DeviceMapSetting,
368        in_situ_quant: Option<IsqType>,
369        paged_attn_config: Option<PagedAttentionConfig>,
370    ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
371        let _progress_guard = ProgressScopeGuard::new(silent);
372        let config = self.read_config_from_path(paths.as_ref())?;
373        self.ensure_loader(&config.contents, config.sentence_transformers_present)?;
374        self.loader
375            .lock()
376            .unwrap()
377            .as_ref()
378            .unwrap()
379            .load_model_from_path(
380                paths,
381                dtype,
382                device,
383                silent,
384                mapper,
385                in_situ_quant,
386                paged_attn_config,
387            )
388    }
389
390    fn get_id(&self) -> String {
391        self.model_id.clone()
392    }
393
394    fn get_kind(&self) -> ModelKind {
395        self.loader
396            .lock()
397            .unwrap()
398            .as_ref()
399            .map(|l| l.get_kind())
400            .unwrap_or(ModelKind::Normal)
401    }
402}