mistralrs_core/pipeline/auto.rs

use super::{
    EmbeddingLoaderBuilder, EmbeddingLoaderType, EmbeddingSpecificConfig, Loader, ModelKind,
    ModelPaths, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, TokenSource,
    VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
};
use crate::api_get_file;
use crate::utils::tokens::get_token;
use crate::Ordering;
use crate::{DeviceMapSetting, IsqType, PagedAttentionConfig, Pipeline, TryIntoDType};
use anyhow::Result;
use candle_core::Device;
use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
use serde::Deserialize;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use tracing::info;

/// Automatically selects a normal (text), vision, or embedding loader based on
/// the `architectures` field of the model's `config.json`.
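///
/// The concrete loader is resolved lazily, the first time a `config.json` is
/// read (e.g. inside `load_model_from_hf`).
///
/// # Example
///
/// A minimal construction sketch, assuming the specific-config values are
/// built elsewhere; the model id and adapter id shown are only illustrative:
///
/// ```ignore
/// let loader = AutoLoaderBuilder::new(
///     normal_cfg,
///     vision_cfg,
///     embedding_cfg,
///     None,                                   // chat_template
///     None,                                   // tokenizer_json
///     "meta-llama/Llama-3.2-1B-Instruct".to_string(),
///     false,                                  // no_kv_cache
///     None,                                   // jinja_explicit
/// )
/// .with_lora(vec!["my-user/my-lora-adapter".to_string()]) // optional
/// .build();
/// ```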
pub struct AutoLoader {
    model_id: String,
    // The candidate builders are held until the architecture is detected, then
    // exactly one is taken out of its `Option` to construct the real loader.
    normal_builder: Mutex<Option<NormalLoaderBuilder>>,
    vision_builder: Mutex<Option<VisionLoaderBuilder>>,
    embedding_builder: Mutex<Option<EmbeddingLoaderBuilder>>,
    loader: Mutex<Option<Box<dyn Loader>>>,
    hf_cache_path: Option<PathBuf>,
}

/// Builder collecting everything needed to construct an [`AutoLoader`].
pub struct AutoLoaderBuilder {
    normal_cfg: NormalSpecificConfig,
    vision_cfg: VisionSpecificConfig,
    embedding_cfg: EmbeddingSpecificConfig,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    model_id: String,
    jinja_explicit: Option<String>,
    no_kv_cache: bool,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    tgt_non_granular_index: Option<usize>,
    lora_adapter_ids: Option<Vec<String>>,
    hf_cache_path: Option<PathBuf>,
}

impl AutoLoaderBuilder {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        normal_cfg: NormalSpecificConfig,
        vision_cfg: VisionSpecificConfig,
        embedding_cfg: EmbeddingSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: String,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            normal_cfg,
            vision_cfg,
            embedding_cfg,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            no_kv_cache,
            xlora_model_id: None,
            xlora_order: None,
            tgt_non_granular_index: None,
            lora_adapter_ids: None,
            hf_cache_path: None,
        }
    }

    pub fn with_xlora(
        mut self,
        model_id: String,
        order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(model_id);
        self.xlora_order = Some(order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self
    }

    pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
        self.lora_adapter_ids = Some(adapters);
        self
    }

    pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
        self.hf_cache_path = Some(path);
        self
    }

    pub fn build(self) -> Box<dyn Loader> {
        let Self {
            normal_cfg,
            vision_cfg,
            embedding_cfg,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            no_kv_cache,
            xlora_model_id,
            xlora_order,
            tgt_non_granular_index,
            lora_adapter_ids,
            hf_cache_path,
        } = self;

        // All three candidate builders are prepared eagerly; the `AutoLoader`
        // consumes exactly one of them once the architecture is known.
        let mut normal_builder = NormalLoaderBuilder::new(
            normal_cfg,
            chat_template.clone(),
            tokenizer_json.clone(),
            Some(model_id.clone()),
            no_kv_cache,
            jinja_explicit.clone(),
        );
        if let (Some(id), Some(ord)) = (xlora_model_id.clone(), xlora_order.clone()) {
            normal_builder =
                normal_builder.with_xlora(id, ord, no_kv_cache, tgt_non_granular_index);
        }
        if let Some(ref adapters) = lora_adapter_ids {
            normal_builder = normal_builder.with_lora(adapters.clone());
        }
        if let Some(ref path) = hf_cache_path {
            normal_builder = normal_builder.hf_cache_path(path.clone());
        }

        let mut vision_builder = VisionLoaderBuilder::new(
            vision_cfg,
            chat_template,
            tokenizer_json.clone(),
            Some(model_id.clone()),
            jinja_explicit,
        );
        if let Some(ref adapters) = lora_adapter_ids {
            vision_builder = vision_builder.with_lora(adapters.clone());
        }
        if let Some(ref path) = hf_cache_path {
            vision_builder = vision_builder.hf_cache_path(path.clone());
        }

        let mut embedding_builder =
            EmbeddingLoaderBuilder::new(embedding_cfg, tokenizer_json, Some(model_id.clone()));
        if let Some(ref adapters) = lora_adapter_ids {
            embedding_builder = embedding_builder.with_lora(adapters.clone());
        }
        if let Some(ref path) = hf_cache_path {
            embedding_builder = embedding_builder.hf_cache_path(path.clone());
        }

        Box::new(AutoLoader {
            model_id,
            normal_builder: Mutex::new(Some(normal_builder)),
            vision_builder: Mutex::new(Some(vision_builder)),
            embedding_builder: Mutex::new(Some(embedding_builder)),
            loader: Mutex::new(None),
            hf_cache_path,
        })
    }
}

/// Minimal view of a model's `config.json`; only the `architectures` list is
/// needed for auto-detection.
#[derive(Deserialize)]
struct AutoConfig {
    architectures: Vec<String>,
}
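// Illustrative `config.json` excerpt that `AutoConfig` deserializes (the
// architecture name is just an example):
//
//     { "architectures": ["LlamaForCausalLM"], ... }
//
// `detect` below tries the vision, embedding, and normal loader registries in
// that order to resolve such a name.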

enum Detected {
    Normal(NormalLoaderType),
    Vision(VisionLoaderType),
    Embedding(EmbeddingLoaderType),
}

impl AutoLoader {
    fn read_config_from_path(&self, paths: &dyn ModelPaths) -> Result<String> {
        Ok(std::fs::read_to_string(paths.get_config_filename())?)
    }

    fn read_config_from_hf(
        &self,
        revision: Option<String>,
        token_source: &TokenSource,
        silent: bool,
    ) -> Result<String> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        let mut api = ApiBuilder::from_cache(cache)
            .with_progress(!silent)
            .with_token(get_token(token_source)?);
        // An explicit `HF_HUB_CACHE` environment variable overrides the cache
        // directory chosen above.
        if let Ok(x) = std::env::var("HF_HUB_CACHE") {
            api = api.with_cache_dir(x.into());
        }
        let api = api.build()?;
        let revision = revision.unwrap_or_else(|| "main".to_string());
        let api = api.repo(Repo::with_revision(
            self.model_id.clone(),
            RepoType::Model,
            revision,
        ));
        let model_id = Path::new(&self.model_id);
        let config_filename = api_get_file!(api, "config.json", model_id);
        Ok(std::fs::read_to_string(config_filename)?)
    }

    /// Maps the single `architectures` entry to a loader type, trying vision,
    /// then embedding, then normal.
    fn detect(&self, config: &str) -> Result<Detected> {
        let cfg: AutoConfig = serde_json::from_str(config)?;
        if cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }
        let name = &cfg.architectures[0];
        if let Ok(tp) = VisionLoaderType::from_causal_lm_name(name) {
            return Ok(Detected::Vision(tp));
        }
        if let Ok(tp) = EmbeddingLoaderType::from_causal_lm_name(name) {
            return Ok(Detected::Embedding(tp));
        }
        let tp = NormalLoaderType::from_causal_lm_name(name)?;
        Ok(Detected::Normal(tp))
    }

    /// Builds and caches the concrete loader on first use; later calls are no-ops.
    fn ensure_loader(&self, config: &str) -> Result<()> {
        let mut guard = self.loader.lock().unwrap();
        if guard.is_some() {
            return Ok(());
        }
        match self.detect(config)? {
            Detected::Normal(tp) => {
                let builder = self
                    .normal_builder
                    .lock()
                    .unwrap()
                    .take()
                    .expect("builder taken");
                let loader = builder.build(Some(tp)).expect("build normal");
                *guard = Some(loader);
            }
            Detected::Vision(tp) => {
                let builder = self
                    .vision_builder
                    .lock()
                    .unwrap()
                    .take()
                    .expect("builder taken");
                let loader = builder.build(Some(tp));
                *guard = Some(loader);
            }
            Detected::Embedding(tp) => {
                let builder = self
                    .embedding_builder
                    .lock()
                    .unwrap()
                    .take()
                    .expect("builder taken");
                let loader = builder.build(Some(tp));
                *guard = Some(loader);
            }
        }
        Ok(())
    }
}

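// `Loader` is implemented by reading `config.json` (from the Hub or from local
// paths), resolving the concrete loader once via `ensure_loader`, and then
// delegating every call to it. A hedged call sketch follows; the `TokenSource`
// and `DeviceMapSetting` values are assumptions about the wider crate, not
// taken from this file:
//
//     let pipeline = loader.load_model_from_hf(
//         None,                    // revision (defaults to "main")
//         TokenSource::CacheToken, // assumed variant name
//         &dtype,                  // anything implementing `TryIntoDType`
//         &Device::Cpu,
//         false,                   // silent
//         device_map_setting,      // a `DeviceMapSetting`
//         None,                    // no in-situ quantization
//         None,                    // no PagedAttention config
//     )?;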
impl Loader for AutoLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
        self.ensure_loader(&config)?;
        self.loader
            .lock()
            .unwrap()
            .as_ref()
            .unwrap()
            .load_model_from_hf(
                revision,
                token_source,
                dtype,
                device,
                silent,
                mapper,
                in_situ_quant,
                paged_attn_config,
            )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let config = self.read_config_from_path(paths.as_ref())?;
        self.ensure_loader(&config)?;
        self.loader
            .lock()
            .unwrap()
            .as_ref()
            .unwrap()
            .load_model_from_path(
                paths,
                dtype,
                device,
                silent,
                mapper,
                in_situ_quant,
                paged_attn_config,
            )
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        self.loader
            .lock()
            .unwrap()
            .as_ref()
            .map(|l| l.get_kind())
            .unwrap_or(ModelKind::Normal)
    }
}