// mistralrs_core/pipeline/auto.rs

1use super::{
2    Loader, ModelKind, ModelPaths, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig,
3    TokenSource, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
4};
5use crate::api_get_file;
6use crate::utils::tokens::get_token;
7use crate::Ordering;
8use crate::{DeviceMapSetting, IsqType, PagedAttentionConfig, Pipeline, TryIntoDType};
9use anyhow::Result;
10use candle_core::Device;
11use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
12use serde::Deserialize;
13use std::path::Path;
14use std::path::PathBuf;
15use std::sync::Arc;
16use std::sync::Mutex;
17use tracing::info;
18
19/// Automatically selects between a normal or vision loader based on the `architectures` field.
20pub struct AutoLoader {
21    model_id: String,
22    normal_builder: Mutex<Option<NormalLoaderBuilder>>,
23    vision_builder: Mutex<Option<VisionLoaderBuilder>>,
24    loader: Mutex<Option<Box<dyn Loader>>>,
25    hf_cache_path: Option<PathBuf>,
26}
27
28pub struct AutoLoaderBuilder {
29    normal_cfg: NormalSpecificConfig,
30    vision_cfg: VisionSpecificConfig,
31    chat_template: Option<String>,
32    tokenizer_json: Option<String>,
33    model_id: String,
34    jinja_explicit: Option<String>,
35    no_kv_cache: bool,
36    xlora_model_id: Option<String>,
37    xlora_order: Option<Ordering>,
38    tgt_non_granular_index: Option<usize>,
39    lora_adapter_ids: Option<Vec<String>>,
40    hf_cache_path: Option<PathBuf>,
41}
42
43impl AutoLoaderBuilder {
44    #[allow(clippy::too_many_arguments)]
45    pub fn new(
46        normal_cfg: NormalSpecificConfig,
47        vision_cfg: VisionSpecificConfig,
48        chat_template: Option<String>,
49        tokenizer_json: Option<String>,
50        model_id: String,
51        no_kv_cache: bool,
52        jinja_explicit: Option<String>,
53    ) -> Self {
54        Self {
55            normal_cfg,
56            vision_cfg,
57            chat_template,
58            tokenizer_json,
59            model_id,
60            jinja_explicit,
61            no_kv_cache,
62            xlora_model_id: None,
63            xlora_order: None,
64            tgt_non_granular_index: None,
65            lora_adapter_ids: None,
66            hf_cache_path: None,
67        }
68    }
69
70    pub fn with_xlora(
71        mut self,
72        model_id: String,
73        order: Ordering,
74        no_kv_cache: bool,
75        tgt_non_granular_index: Option<usize>,
76    ) -> Self {
77        self.xlora_model_id = Some(model_id);
78        self.xlora_order = Some(order);
79        self.no_kv_cache = no_kv_cache;
80        self.tgt_non_granular_index = tgt_non_granular_index;
81        self
82    }
83
84    pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
85        self.lora_adapter_ids = Some(adapters);
86        self
87    }
88
89    pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
90        self.hf_cache_path = Some(path);
91        self
92    }
93
94    pub fn build(self) -> Box<dyn Loader> {
95        let model_id = self.model_id.clone();
96        let mut normal_builder = NormalLoaderBuilder::new(
97            self.normal_cfg,
98            self.chat_template.clone(),
99            self.tokenizer_json.clone(),
100            Some(model_id.clone()),
101            self.no_kv_cache,
102            self.jinja_explicit.clone(),
103        );
104        if let (Some(id), Some(ord)) = (self.xlora_model_id.clone(), self.xlora_order.clone()) {
105            normal_builder =
106                normal_builder.with_xlora(id, ord, self.no_kv_cache, self.tgt_non_granular_index);
107        }
108        if let Some(ref adapters) = self.lora_adapter_ids {
109            normal_builder = normal_builder.with_lora(adapters.clone());
110        }
111        if let Some(ref path) = self.hf_cache_path {
112            normal_builder = normal_builder.hf_cache_path(path.clone());
113        }
114
115        let mut vision_builder = VisionLoaderBuilder::new(
116            self.vision_cfg,
117            self.chat_template,
118            self.tokenizer_json,
119            Some(model_id.clone()),
120            self.jinja_explicit,
121        );
122        if let Some(ref adapters) = self.lora_adapter_ids {
123            vision_builder = vision_builder.with_lora(adapters.clone());
124        }
125        if let Some(ref path) = self.hf_cache_path {
126            vision_builder = vision_builder.hf_cache_path(path.clone());
127        }
128
129        Box::new(AutoLoader {
130            model_id,
131            normal_builder: Mutex::new(Some(normal_builder)),
132            vision_builder: Mutex::new(Some(vision_builder)),
133            loader: Mutex::new(None),
134            hf_cache_path: self.hf_cache_path,
135        })
136    }
137}
138
139#[derive(Deserialize)]
140struct AutoConfig {
141    architectures: Vec<String>,
142}
143
144enum Detected {
145    Normal(NormalLoaderType),
146    Vision(VisionLoaderType),
147}
148
149impl AutoLoader {
150    fn read_config_from_path(&self, paths: &dyn ModelPaths) -> Result<String> {
151        Ok(std::fs::read_to_string(paths.get_config_filename())?)
152    }
153
154    fn read_config_from_hf(
155        &self,
156        revision: Option<String>,
157        token_source: &TokenSource,
158        silent: bool,
159    ) -> Result<String> {
160        let cache = self
161            .hf_cache_path
162            .clone()
163            .map(Cache::new)
164            .unwrap_or_default();
165        let mut api = ApiBuilder::from_cache(cache)
166            .with_progress(!silent)
167            .with_token(get_token(token_source)?);
168        if let Ok(x) = std::env::var("HF_HUB_CACHE") {
169            api = api.with_cache_dir(x.into());
170        }
171        let api = api.build()?;
172        let revision = revision.unwrap_or_else(|| "main".to_string());
173        let api = api.repo(Repo::with_revision(
174            self.model_id.clone(),
175            RepoType::Model,
176            revision,
177        ));
178        let model_id = Path::new(&self.model_id);
179        let config_filename = api_get_file!(api, "config.json", model_id);
180        Ok(std::fs::read_to_string(config_filename)?)
181    }
182
183    fn detect(&self, config: &str) -> Result<Detected> {
184        let cfg: AutoConfig = serde_json::from_str(config)?;
185        if cfg.architectures.len() != 1 {
186            anyhow::bail!("Expected exactly one architecture in config");
187        }
188        let name = &cfg.architectures[0];
189        if let Ok(tp) = VisionLoaderType::from_causal_lm_name(name) {
190            return Ok(Detected::Vision(tp));
191        }
192        let tp = NormalLoaderType::from_causal_lm_name(name)?;
193        Ok(Detected::Normal(tp))
194    }
195
196    fn ensure_loader(&self, config: &str) -> Result<()> {
197        let mut guard = self.loader.lock().unwrap();
198        if guard.is_some() {
199            return Ok(());
200        }
201        match self.detect(config)? {
202            Detected::Normal(tp) => {
203                let builder = self
204                    .normal_builder
205                    .lock()
206                    .unwrap()
207                    .take()
208                    .expect("builder taken");
209                let loader = builder.build(Some(tp)).expect("build normal");
210                *guard = Some(loader);
211            }
212            Detected::Vision(tp) => {
213                let builder = self
214                    .vision_builder
215                    .lock()
216                    .unwrap()
217                    .take()
218                    .expect("builder taken");
219                let loader = builder.build(Some(tp));
220                *guard = Some(loader);
221            }
222        }
223        Ok(())
224    }
225}
226
227impl Loader for AutoLoader {
228    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
229    fn load_model_from_hf(
230        &self,
231        revision: Option<String>,
232        token_source: TokenSource,
233        dtype: &dyn TryIntoDType,
234        device: &Device,
235        silent: bool,
236        mapper: DeviceMapSetting,
237        in_situ_quant: Option<IsqType>,
238        paged_attn_config: Option<PagedAttentionConfig>,
239    ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
240        let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
241        self.ensure_loader(&config)?;
242        self.loader
243            .lock()
244            .unwrap()
245            .as_ref()
246            .unwrap()
247            .load_model_from_hf(
248                revision,
249                token_source,
250                dtype,
251                device,
252                silent,
253                mapper,
254                in_situ_quant,
255                paged_attn_config,
256            )
257    }
258
259    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
260    fn load_model_from_path(
261        &self,
262        paths: &Box<dyn ModelPaths>,
263        dtype: &dyn TryIntoDType,
264        device: &Device,
265        silent: bool,
266        mapper: DeviceMapSetting,
267        in_situ_quant: Option<IsqType>,
268        paged_attn_config: Option<PagedAttentionConfig>,
269    ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
270        let config = self.read_config_from_path(paths.as_ref())?;
271        self.ensure_loader(&config)?;
272        self.loader
273            .lock()
274            .unwrap()
275            .as_ref()
276            .unwrap()
277            .load_model_from_path(
278                paths,
279                dtype,
280                device,
281                silent,
282                mapper,
283                in_situ_quant,
284                paged_attn_config,
285            )
286    }
287
288    fn get_id(&self) -> String {
289        self.model_id.clone()
290    }
291
292    fn get_kind(&self) -> ModelKind {
293        self.loader
294            .lock()
295            .unwrap()
296            .as_ref()
297            .map(|l| l.get_kind())
298            .unwrap_or(ModelKind::Normal)
299    }
300}