1use super::{
2 EmbeddingLoaderBuilder, EmbeddingLoaderType, EmbeddingSpecificConfig, Loader, ModelKind,
3 ModelPaths, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, TokenSource,
4 VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
5};
6use crate::api_get_file;
7use crate::utils::tokens::get_token;
8use crate::Ordering;
9use crate::{DeviceMapSetting, IsqType, PagedAttentionConfig, Pipeline, TryIntoDType};
10use anyhow::Result;
11use candle_core::Device;
12use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
13use serde::Deserialize;
14use std::path::Path;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::sync::Mutex;
18use tracing::info;
19
20pub struct AutoLoader {
22 model_id: String,
23 normal_builder: Mutex<Option<NormalLoaderBuilder>>,
24 vision_builder: Mutex<Option<VisionLoaderBuilder>>,
25 embedding_builder: Mutex<Option<EmbeddingLoaderBuilder>>,
26 loader: Mutex<Option<Box<dyn Loader>>>,
27 hf_cache_path: Option<PathBuf>,
28}
29
30pub struct AutoLoaderBuilder {
31 normal_cfg: NormalSpecificConfig,
32 vision_cfg: VisionSpecificConfig,
33 embedding_cfg: EmbeddingSpecificConfig,
34 chat_template: Option<String>,
35 tokenizer_json: Option<String>,
36 model_id: String,
37 jinja_explicit: Option<String>,
38 no_kv_cache: bool,
39 xlora_model_id: Option<String>,
40 xlora_order: Option<Ordering>,
41 tgt_non_granular_index: Option<usize>,
42 lora_adapter_ids: Option<Vec<String>>,
43 hf_cache_path: Option<PathBuf>,
44}
45
46impl AutoLoaderBuilder {
47 #[allow(clippy::too_many_arguments)]
48 pub fn new(
49 normal_cfg: NormalSpecificConfig,
50 vision_cfg: VisionSpecificConfig,
51 embedding_cfg: EmbeddingSpecificConfig,
52 chat_template: Option<String>,
53 tokenizer_json: Option<String>,
54 model_id: String,
55 no_kv_cache: bool,
56 jinja_explicit: Option<String>,
57 ) -> Self {
58 Self {
59 normal_cfg,
60 vision_cfg,
61 embedding_cfg,
62 chat_template,
63 tokenizer_json,
64 model_id,
65 jinja_explicit,
66 no_kv_cache,
67 xlora_model_id: None,
68 xlora_order: None,
69 tgt_non_granular_index: None,
70 lora_adapter_ids: None,
71 hf_cache_path: None,
72 }
73 }
74
75 pub fn with_xlora(
76 mut self,
77 model_id: String,
78 order: Ordering,
79 no_kv_cache: bool,
80 tgt_non_granular_index: Option<usize>,
81 ) -> Self {
82 self.xlora_model_id = Some(model_id);
83 self.xlora_order = Some(order);
84 self.no_kv_cache = no_kv_cache;
85 self.tgt_non_granular_index = tgt_non_granular_index;
86 self
87 }
88
89 pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
90 self.lora_adapter_ids = Some(adapters);
91 self
92 }
93
94 pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
95 self.hf_cache_path = Some(path);
96 self
97 }
98
99 pub fn build(self) -> Box<dyn Loader> {
100 let Self {
101 normal_cfg,
102 vision_cfg,
103 embedding_cfg,
104 chat_template,
105 tokenizer_json,
106 model_id,
107 jinja_explicit,
108 no_kv_cache,
109 xlora_model_id,
110 xlora_order,
111 tgt_non_granular_index,
112 lora_adapter_ids,
113 hf_cache_path,
114 } = self;
115
116 let mut normal_builder = NormalLoaderBuilder::new(
117 normal_cfg,
118 chat_template.clone(),
119 tokenizer_json.clone(),
120 Some(model_id.clone()),
121 no_kv_cache,
122 jinja_explicit.clone(),
123 );
124 if let (Some(id), Some(ord)) = (xlora_model_id.clone(), xlora_order.clone()) {
125 normal_builder =
126 normal_builder.with_xlora(id, ord, no_kv_cache, tgt_non_granular_index);
127 }
128 if let Some(ref adapters) = lora_adapter_ids {
129 normal_builder = normal_builder.with_lora(adapters.clone());
130 }
131 if let Some(ref path) = hf_cache_path {
132 normal_builder = normal_builder.hf_cache_path(path.clone());
133 }
134
135 let mut vision_builder = VisionLoaderBuilder::new(
136 vision_cfg,
137 chat_template,
138 tokenizer_json.clone(),
139 Some(model_id.clone()),
140 jinja_explicit,
141 );
142 if let Some(ref adapters) = lora_adapter_ids {
143 vision_builder = vision_builder.with_lora(adapters.clone());
144 }
145 if let Some(ref path) = hf_cache_path {
146 vision_builder = vision_builder.hf_cache_path(path.clone());
147 }
148
149 let mut embedding_builder =
150 EmbeddingLoaderBuilder::new(embedding_cfg, tokenizer_json, Some(model_id.clone()));
151 if let Some(ref adapters) = lora_adapter_ids {
152 embedding_builder = embedding_builder.with_lora(adapters.clone());
153 }
154 if let Some(ref path) = hf_cache_path {
155 embedding_builder = embedding_builder.hf_cache_path(path.clone());
156 }
157
158 Box::new(AutoLoader {
159 model_id,
160 normal_builder: Mutex::new(Some(normal_builder)),
161 vision_builder: Mutex::new(Some(vision_builder)),
162 embedding_builder: Mutex::new(Some(embedding_builder)),
163 loader: Mutex::new(None),
164 hf_cache_path,
165 })
166 }
167}
168
169#[derive(Deserialize)]
170struct AutoConfig {
171 architectures: Vec<String>,
172}
173
174enum Detected {
175 Normal(NormalLoaderType),
176 Vision(VisionLoaderType),
177 Embedding(EmbeddingLoaderType),
178}
179
180impl AutoLoader {
181 fn read_config_from_path(&self, paths: &dyn ModelPaths) -> Result<String> {
182 Ok(std::fs::read_to_string(paths.get_config_filename())?)
183 }
184
185 fn read_config_from_hf(
186 &self,
187 revision: Option<String>,
188 token_source: &TokenSource,
189 silent: bool,
190 ) -> Result<String> {
191 let cache = self
192 .hf_cache_path
193 .clone()
194 .map(Cache::new)
195 .unwrap_or_default();
196 let mut api = ApiBuilder::from_cache(cache)
197 .with_progress(!silent)
198 .with_token(get_token(token_source)?);
199 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
200 api = api.with_cache_dir(x.into());
201 }
202 let api = api.build()?;
203 let revision = revision.unwrap_or_else(|| "main".to_string());
204 let api = api.repo(Repo::with_revision(
205 self.model_id.clone(),
206 RepoType::Model,
207 revision,
208 ));
209 let model_id = Path::new(&self.model_id);
210 let config_filename = api_get_file!(api, "config.json", model_id);
211 Ok(std::fs::read_to_string(config_filename)?)
212 }
213
214 fn detect(&self, config: &str) -> Result<Detected> {
215 let cfg: AutoConfig = serde_json::from_str(config)?;
216 if cfg.architectures.len() != 1 {
217 anyhow::bail!("Expected exactly one architecture in config");
218 }
219 let name = &cfg.architectures[0];
220 if let Ok(tp) = VisionLoaderType::from_causal_lm_name(name) {
221 return Ok(Detected::Vision(tp));
222 }
223 if let Ok(tp) = EmbeddingLoaderType::from_causal_lm_name(name) {
224 return Ok(Detected::Embedding(tp));
225 }
226 let tp = NormalLoaderType::from_causal_lm_name(name)?;
227 Ok(Detected::Normal(tp))
228 }
229
230 fn ensure_loader(&self, config: &str) -> Result<()> {
231 let mut guard = self.loader.lock().unwrap();
232 if guard.is_some() {
233 return Ok(());
234 }
235 match self.detect(config)? {
236 Detected::Normal(tp) => {
237 let builder = self
238 .normal_builder
239 .lock()
240 .unwrap()
241 .take()
242 .expect("builder taken");
243 let loader = builder.build(Some(tp)).expect("build normal");
244 *guard = Some(loader);
245 }
246 Detected::Vision(tp) => {
247 let builder = self
248 .vision_builder
249 .lock()
250 .unwrap()
251 .take()
252 .expect("builder taken");
253 let loader = builder.build(Some(tp));
254 *guard = Some(loader);
255 }
256 Detected::Embedding(tp) => {
257 let builder = self
258 .embedding_builder
259 .lock()
260 .unwrap()
261 .take()
262 .expect("builder taken");
263 let loader = builder.build(Some(tp));
264 *guard = Some(loader);
265 }
266 }
267 Ok(())
268 }
269}
270
271impl Loader for AutoLoader {
272 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
273 fn load_model_from_hf(
274 &self,
275 revision: Option<String>,
276 token_source: TokenSource,
277 dtype: &dyn TryIntoDType,
278 device: &Device,
279 silent: bool,
280 mapper: DeviceMapSetting,
281 in_situ_quant: Option<IsqType>,
282 paged_attn_config: Option<PagedAttentionConfig>,
283 ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
284 let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
285 self.ensure_loader(&config)?;
286 self.loader
287 .lock()
288 .unwrap()
289 .as_ref()
290 .unwrap()
291 .load_model_from_hf(
292 revision,
293 token_source,
294 dtype,
295 device,
296 silent,
297 mapper,
298 in_situ_quant,
299 paged_attn_config,
300 )
301 }
302
303 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
304 fn load_model_from_path(
305 &self,
306 paths: &Box<dyn ModelPaths>,
307 dtype: &dyn TryIntoDType,
308 device: &Device,
309 silent: bool,
310 mapper: DeviceMapSetting,
311 in_situ_quant: Option<IsqType>,
312 paged_attn_config: Option<PagedAttentionConfig>,
313 ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
314 let config = self.read_config_from_path(paths.as_ref())?;
315 self.ensure_loader(&config)?;
316 self.loader
317 .lock()
318 .unwrap()
319 .as_ref()
320 .unwrap()
321 .load_model_from_path(
322 paths,
323 dtype,
324 device,
325 silent,
326 mapper,
327 in_situ_quant,
328 paged_attn_config,
329 )
330 }
331
332 fn get_id(&self) -> String {
333 self.model_id.clone()
334 }
335
336 fn get_kind(&self) -> ModelKind {
337 self.loader
338 .lock()
339 .unwrap()
340 .as_ref()
341 .map(|l| l.get_kind())
342 .unwrap_or(ModelKind::Normal)
343 }
344}