1use super::{
2 EmbeddingLoaderBuilder, EmbeddingLoaderType, EmbeddingSpecificConfig, Loader, ModelKind,
3 ModelPaths, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, TokenSource,
4 VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
5};
6use crate::api_get_file;
7use crate::utils::{progress::ProgressScopeGuard, tokens::get_token};
8use crate::Ordering;
9use crate::{DeviceMapSetting, IsqType, PagedAttentionConfig, Pipeline, TryIntoDType};
10use anyhow::Result;
11use candle_core::Device;
12use hf_hub::{
13 api::sync::{ApiBuilder, ApiRepo},
14 Cache, Repo, RepoType,
15};
16use serde::Deserialize;
17use std::path::Path;
18use std::path::PathBuf;
19use std::sync::Arc;
20use std::sync::Mutex;
21use tracing::{debug, info};
22
/// Loader that inspects a model's `config.json` to decide whether the model
/// is a normal text model, a vision model, or an embedding model, then
/// delegates all loading to the matching concrete loader.
pub struct AutoLoader {
    // Model identifier (hub repo id or local path).
    model_id: String,
    // Pre-configured builders for each loader kind; exactly one is `take`n
    // when the model kind is detected (see `ensure_loader`).
    normal_builder: Mutex<Option<NormalLoaderBuilder>>,
    vision_builder: Mutex<Option<VisionLoaderBuilder>>,
    embedding_builder: Mutex<Option<EmbeddingLoaderBuilder>>,
    // Lazily-constructed concrete loader; `None` until first load call.
    loader: Mutex<Option<Box<dyn Loader>>>,
    // Optional override for the Hugging Face hub cache location.
    hf_cache_path: Option<PathBuf>,
}
32
/// Collects configuration for every supported loader kind up front, since the
/// actual model kind is unknown until its config is read; `build` packages
/// everything into an [`AutoLoader`].
pub struct AutoLoaderBuilder {
    // Kind-specific configs; only the one matching the detected kind is used.
    normal_cfg: NormalSpecificConfig,
    vision_cfg: VisionSpecificConfig,
    embedding_cfg: EmbeddingSpecificConfig,
    // Optional chat template override passed through to the concrete builders.
    chat_template: Option<String>,
    // Optional tokenizer.json override.
    tokenizer_json: Option<String>,
    model_id: String,
    // Optional explicit Jinja template (normal/vision loaders only).
    jinja_explicit: Option<String>,
    no_kv_cache: bool,
    // X-LoRA settings; set via `with_xlora` and applied to the normal builder.
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    tgt_non_granular_index: Option<usize>,
    // LoRA adapter ids; applied to all three concrete builders.
    lora_adapter_ids: Option<Vec<String>>,
    // Optional override for the Hugging Face hub cache location.
    hf_cache_path: Option<PathBuf>,
}
48
49impl AutoLoaderBuilder {
50 #[allow(clippy::too_many_arguments)]
51 pub fn new(
52 normal_cfg: NormalSpecificConfig,
53 vision_cfg: VisionSpecificConfig,
54 embedding_cfg: EmbeddingSpecificConfig,
55 chat_template: Option<String>,
56 tokenizer_json: Option<String>,
57 model_id: String,
58 no_kv_cache: bool,
59 jinja_explicit: Option<String>,
60 ) -> Self {
61 Self {
62 normal_cfg,
63 vision_cfg,
64 embedding_cfg,
65 chat_template,
66 tokenizer_json,
67 model_id,
68 jinja_explicit,
69 no_kv_cache,
70 xlora_model_id: None,
71 xlora_order: None,
72 tgt_non_granular_index: None,
73 lora_adapter_ids: None,
74 hf_cache_path: None,
75 }
76 }
77
78 pub fn with_xlora(
79 mut self,
80 model_id: String,
81 order: Ordering,
82 no_kv_cache: bool,
83 tgt_non_granular_index: Option<usize>,
84 ) -> Self {
85 self.xlora_model_id = Some(model_id);
86 self.xlora_order = Some(order);
87 self.no_kv_cache = no_kv_cache;
88 self.tgt_non_granular_index = tgt_non_granular_index;
89 self
90 }
91
92 pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
93 self.lora_adapter_ids = Some(adapters);
94 self
95 }
96
97 pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
98 self.hf_cache_path = Some(path);
99 self
100 }
101
102 pub fn build(self) -> Box<dyn Loader> {
103 let Self {
104 normal_cfg,
105 vision_cfg,
106 embedding_cfg,
107 chat_template,
108 tokenizer_json,
109 model_id,
110 jinja_explicit,
111 no_kv_cache,
112 xlora_model_id,
113 xlora_order,
114 tgt_non_granular_index,
115 lora_adapter_ids,
116 hf_cache_path,
117 } = self;
118
119 let mut normal_builder = NormalLoaderBuilder::new(
120 normal_cfg,
121 chat_template.clone(),
122 tokenizer_json.clone(),
123 Some(model_id.clone()),
124 no_kv_cache,
125 jinja_explicit.clone(),
126 );
127 if let (Some(id), Some(ord)) = (xlora_model_id.clone(), xlora_order.clone()) {
128 normal_builder =
129 normal_builder.with_xlora(id, ord, no_kv_cache, tgt_non_granular_index);
130 }
131 if let Some(ref adapters) = lora_adapter_ids {
132 normal_builder = normal_builder.with_lora(adapters.clone());
133 }
134 if let Some(ref path) = hf_cache_path {
135 normal_builder = normal_builder.hf_cache_path(path.clone());
136 }
137
138 let mut vision_builder = VisionLoaderBuilder::new(
139 vision_cfg,
140 chat_template,
141 tokenizer_json.clone(),
142 Some(model_id.clone()),
143 jinja_explicit,
144 );
145 if let Some(ref adapters) = lora_adapter_ids {
146 vision_builder = vision_builder.with_lora(adapters.clone());
147 }
148 if let Some(ref path) = hf_cache_path {
149 vision_builder = vision_builder.hf_cache_path(path.clone());
150 }
151
152 let mut embedding_builder =
153 EmbeddingLoaderBuilder::new(embedding_cfg, tokenizer_json, Some(model_id.clone()));
154 if let Some(ref adapters) = lora_adapter_ids {
155 embedding_builder = embedding_builder.with_lora(adapters.clone());
156 }
157 if let Some(ref path) = hf_cache_path {
158 embedding_builder = embedding_builder.hf_cache_path(path.clone());
159 }
160
161 Box::new(AutoLoader {
162 model_id,
163 normal_builder: Mutex::new(Some(normal_builder)),
164 vision_builder: Mutex::new(Some(vision_builder)),
165 embedding_builder: Mutex::new(Some(embedding_builder)),
166 loader: Mutex::new(None),
167 hf_cache_path,
168 })
169 }
170}
171
/// Minimal view of a model's `config.json`: only the `architectures` list is
/// needed to detect the loader kind.
#[derive(Deserialize)]
struct AutoConfig {
    // Defaults to an empty list when the key is absent from the config.
    #[serde(default)]
    architectures: Vec<String>,
}
177
/// Result of resolving a model's config: the raw `config.json` text plus
/// whether a `config_sentence_transformers.json` was found alongside it
/// (which routes the model to the embedding loader).
struct ConfigArtifacts {
    contents: String,
    sentence_transformers_present: bool,
}
182
/// Loader kind derived from the model config by `AutoLoader::detect`.
enum Detected {
    Normal(NormalLoaderType),
    Vision(VisionLoaderType),
    // `None` means "embedding model of unspecified type": the embedding
    // builder performs its own auto-detection.
    Embedding(Option<EmbeddingLoaderType>),
}
188
189impl AutoLoader {
190 fn read_config_from_path(&self, paths: &dyn ModelPaths) -> Result<ConfigArtifacts> {
191 let config_path = paths.get_config_filename();
192 let contents = std::fs::read_to_string(config_path)?;
193 let sentence_transformers_present = Self::has_sentence_transformers_sibling(config_path);
194 Ok(ConfigArtifacts {
195 contents,
196 sentence_transformers_present,
197 })
198 }
199
200 fn read_config_from_hf(
201 &self,
202 revision: Option<String>,
203 token_source: &TokenSource,
204 silent: bool,
205 ) -> Result<ConfigArtifacts> {
206 let cache = self
207 .hf_cache_path
208 .clone()
209 .map(Cache::new)
210 .unwrap_or_default();
211 let mut api = ApiBuilder::from_cache(cache)
212 .with_progress(!silent)
213 .with_token(get_token(token_source)?);
214 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
215 api = api.with_cache_dir(x.into());
216 }
217 let api = api.build()?;
218 let revision = revision.unwrap_or_else(|| "main".to_string());
219 let api = api.repo(Repo::with_revision(
220 self.model_id.clone(),
221 RepoType::Model,
222 revision,
223 ));
224 let model_id = Path::new(&self.model_id);
225 let config_filename = api_get_file!(api, "config.json", model_id);
226 let contents = std::fs::read_to_string(&config_filename)?;
227 let sentence_transformers_present =
228 Self::has_sentence_transformers_sibling(&config_filename)
229 || Self::fetch_sentence_transformers_config(&api, model_id);
230 Ok(ConfigArtifacts {
231 contents,
232 sentence_transformers_present,
233 })
234 }
235
236 fn has_sentence_transformers_sibling(config_path: &Path) -> bool {
237 config_path
238 .parent()
239 .map(|parent| parent.join("config_sentence_transformers.json").exists())
240 .unwrap_or(false)
241 }
242
243 fn fetch_sentence_transformers_config(api: &ApiRepo, model_id: &Path) -> bool {
244 if model_id.exists() {
245 return false;
246 }
247 match api.get("config_sentence_transformers.json") {
248 Ok(_) => true,
249 Err(err) => {
250 debug!(
251 "No `config_sentence_transformers.json` found for `{}`: {err}",
252 model_id.display()
253 );
254 false
255 }
256 }
257 }
258
259 fn detect(&self, config: &str, allow_embedding: bool) -> Result<Detected> {
260 let cfg: AutoConfig = serde_json::from_str(config)?;
261 if allow_embedding {
262 if let Some(name) = cfg.architectures.first() {
263 if let Ok(tp) = EmbeddingLoaderType::from_causal_lm_name(name) {
264 info!(
265 "Detected `config_sentence_transformers.json`; using embedding loader `{tp}`."
266 );
267 return Ok(Detected::Embedding(Some(tp)));
268 }
269 }
270 info!(
271 "Detected `config_sentence_transformers.json`; routing via auto embedding loader."
272 );
273 return Ok(Detected::Embedding(None));
274 }
275 if cfg.architectures.len() != 1 {
276 anyhow::bail!("Expected exactly one architecture in config");
277 }
278 let name = &cfg.architectures[0];
279 if let Ok(tp) = VisionLoaderType::from_causal_lm_name(name) {
280 return Ok(Detected::Vision(tp));
281 }
282 let tp = NormalLoaderType::from_causal_lm_name(name)?;
283 Ok(Detected::Normal(tp))
284 }
285
286 fn ensure_loader(&self, config: &str, allow_embedding: bool) -> Result<()> {
287 let mut guard = self.loader.lock().unwrap();
288 if guard.is_some() {
289 return Ok(());
290 }
291 match self.detect(config, allow_embedding)? {
292 Detected::Normal(tp) => {
293 let builder = self
294 .normal_builder
295 .lock()
296 .unwrap()
297 .take()
298 .expect("builder taken");
299 let loader = builder.build(Some(tp)).expect("build normal");
300 *guard = Some(loader);
301 }
302 Detected::Vision(tp) => {
303 let builder = self
304 .vision_builder
305 .lock()
306 .unwrap()
307 .take()
308 .expect("builder taken");
309 let loader = builder.build(Some(tp));
310 *guard = Some(loader);
311 }
312 Detected::Embedding(tp) => {
313 let builder = self
314 .embedding_builder
315 .lock()
316 .unwrap()
317 .take()
318 .expect("builder taken");
319 let loader = builder.build(tp);
320 *guard = Some(loader);
321 }
322 }
323 Ok(())
324 }
325}
326
327impl Loader for AutoLoader {
328 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
329 fn load_model_from_hf(
330 &self,
331 revision: Option<String>,
332 token_source: TokenSource,
333 dtype: &dyn TryIntoDType,
334 device: &Device,
335 silent: bool,
336 mapper: DeviceMapSetting,
337 in_situ_quant: Option<IsqType>,
338 paged_attn_config: Option<PagedAttentionConfig>,
339 ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
340 let _progress_guard = ProgressScopeGuard::new(silent);
341 let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
342 self.ensure_loader(&config.contents, config.sentence_transformers_present)?;
343 self.loader
344 .lock()
345 .unwrap()
346 .as_ref()
347 .unwrap()
348 .load_model_from_hf(
349 revision,
350 token_source,
351 dtype,
352 device,
353 silent,
354 mapper,
355 in_situ_quant,
356 paged_attn_config,
357 )
358 }
359
360 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
361 fn load_model_from_path(
362 &self,
363 paths: &Box<dyn ModelPaths>,
364 dtype: &dyn TryIntoDType,
365 device: &Device,
366 silent: bool,
367 mapper: DeviceMapSetting,
368 in_situ_quant: Option<IsqType>,
369 paged_attn_config: Option<PagedAttentionConfig>,
370 ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
371 let _progress_guard = ProgressScopeGuard::new(silent);
372 let config = self.read_config_from_path(paths.as_ref())?;
373 self.ensure_loader(&config.contents, config.sentence_transformers_present)?;
374 self.loader
375 .lock()
376 .unwrap()
377 .as_ref()
378 .unwrap()
379 .load_model_from_path(
380 paths,
381 dtype,
382 device,
383 silent,
384 mapper,
385 in_situ_quant,
386 paged_attn_config,
387 )
388 }
389
390 fn get_id(&self) -> String {
391 self.model_id.clone()
392 }
393
394 fn get_kind(&self) -> ModelKind {
395 self.loader
396 .lock()
397 .unwrap()
398 .as_ref()
399 .map(|l| l.get_kind())
400 .unwrap_or(ModelKind::Normal)
401 }
402}