1use super::{
2 Loader, ModelKind, ModelPaths, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig,
3 TokenSource, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
4};
5use crate::api_get_file;
6use crate::utils::tokens::get_token;
7use crate::Ordering;
8use crate::{DeviceMapSetting, IsqType, PagedAttentionConfig, Pipeline, TryIntoDType};
9use anyhow::Result;
10use candle_core::Device;
11use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
12use serde::Deserialize;
13use std::path::Path;
14use std::path::PathBuf;
15use std::sync::Arc;
16use std::sync::Mutex;
17use tracing::info;
18
19pub struct AutoLoader {
21 model_id: String,
22 normal_builder: Mutex<Option<NormalLoaderBuilder>>,
23 vision_builder: Mutex<Option<VisionLoaderBuilder>>,
24 loader: Mutex<Option<Box<dyn Loader>>>,
25 hf_cache_path: Option<PathBuf>,
26}
27
28pub struct AutoLoaderBuilder {
29 normal_cfg: NormalSpecificConfig,
30 vision_cfg: VisionSpecificConfig,
31 chat_template: Option<String>,
32 tokenizer_json: Option<String>,
33 model_id: String,
34 jinja_explicit: Option<String>,
35 no_kv_cache: bool,
36 xlora_model_id: Option<String>,
37 xlora_order: Option<Ordering>,
38 tgt_non_granular_index: Option<usize>,
39 lora_adapter_ids: Option<Vec<String>>,
40 hf_cache_path: Option<PathBuf>,
41}
42
43impl AutoLoaderBuilder {
44 #[allow(clippy::too_many_arguments)]
45 pub fn new(
46 normal_cfg: NormalSpecificConfig,
47 vision_cfg: VisionSpecificConfig,
48 chat_template: Option<String>,
49 tokenizer_json: Option<String>,
50 model_id: String,
51 no_kv_cache: bool,
52 jinja_explicit: Option<String>,
53 ) -> Self {
54 Self {
55 normal_cfg,
56 vision_cfg,
57 chat_template,
58 tokenizer_json,
59 model_id,
60 jinja_explicit,
61 no_kv_cache,
62 xlora_model_id: None,
63 xlora_order: None,
64 tgt_non_granular_index: None,
65 lora_adapter_ids: None,
66 hf_cache_path: None,
67 }
68 }
69
70 pub fn with_xlora(
71 mut self,
72 model_id: String,
73 order: Ordering,
74 no_kv_cache: bool,
75 tgt_non_granular_index: Option<usize>,
76 ) -> Self {
77 self.xlora_model_id = Some(model_id);
78 self.xlora_order = Some(order);
79 self.no_kv_cache = no_kv_cache;
80 self.tgt_non_granular_index = tgt_non_granular_index;
81 self
82 }
83
84 pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
85 self.lora_adapter_ids = Some(adapters);
86 self
87 }
88
89 pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
90 self.hf_cache_path = Some(path);
91 self
92 }
93
94 pub fn build(self) -> Box<dyn Loader> {
95 let model_id = self.model_id.clone();
96 let mut normal_builder = NormalLoaderBuilder::new(
97 self.normal_cfg,
98 self.chat_template.clone(),
99 self.tokenizer_json.clone(),
100 Some(model_id.clone()),
101 self.no_kv_cache,
102 self.jinja_explicit.clone(),
103 );
104 if let (Some(id), Some(ord)) = (self.xlora_model_id.clone(), self.xlora_order.clone()) {
105 normal_builder =
106 normal_builder.with_xlora(id, ord, self.no_kv_cache, self.tgt_non_granular_index);
107 }
108 if let Some(ref adapters) = self.lora_adapter_ids {
109 normal_builder = normal_builder.with_lora(adapters.clone());
110 }
111 if let Some(ref path) = self.hf_cache_path {
112 normal_builder = normal_builder.hf_cache_path(path.clone());
113 }
114
115 let mut vision_builder = VisionLoaderBuilder::new(
116 self.vision_cfg,
117 self.chat_template,
118 self.tokenizer_json,
119 Some(model_id.clone()),
120 self.jinja_explicit,
121 );
122 if let Some(ref adapters) = self.lora_adapter_ids {
123 vision_builder = vision_builder.with_lora(adapters.clone());
124 }
125 if let Some(ref path) = self.hf_cache_path {
126 vision_builder = vision_builder.hf_cache_path(path.clone());
127 }
128
129 Box::new(AutoLoader {
130 model_id,
131 normal_builder: Mutex::new(Some(normal_builder)),
132 vision_builder: Mutex::new(Some(vision_builder)),
133 loader: Mutex::new(None),
134 hf_cache_path: self.hf_cache_path,
135 })
136 }
137}
138
139#[derive(Deserialize)]
140struct AutoConfig {
141 architectures: Vec<String>,
142}
143
144enum Detected {
145 Normal(NormalLoaderType),
146 Vision(VisionLoaderType),
147}
148
149impl AutoLoader {
150 fn read_config_from_path(&self, paths: &dyn ModelPaths) -> Result<String> {
151 Ok(std::fs::read_to_string(paths.get_config_filename())?)
152 }
153
154 fn read_config_from_hf(
155 &self,
156 revision: Option<String>,
157 token_source: &TokenSource,
158 silent: bool,
159 ) -> Result<String> {
160 let cache = self
161 .hf_cache_path
162 .clone()
163 .map(Cache::new)
164 .unwrap_or_default();
165 let mut api = ApiBuilder::from_cache(cache)
166 .with_progress(!silent)
167 .with_token(get_token(token_source)?);
168 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
169 api = api.with_cache_dir(x.into());
170 }
171 let api = api.build()?;
172 let revision = revision.unwrap_or_else(|| "main".to_string());
173 let api = api.repo(Repo::with_revision(
174 self.model_id.clone(),
175 RepoType::Model,
176 revision,
177 ));
178 let model_id = Path::new(&self.model_id);
179 let config_filename = api_get_file!(api, "config.json", model_id);
180 Ok(std::fs::read_to_string(config_filename)?)
181 }
182
183 fn detect(&self, config: &str) -> Result<Detected> {
184 let cfg: AutoConfig = serde_json::from_str(config)?;
185 if cfg.architectures.len() != 1 {
186 anyhow::bail!("Expected exactly one architecture in config");
187 }
188 let name = &cfg.architectures[0];
189 if let Ok(tp) = VisionLoaderType::from_causal_lm_name(name) {
190 return Ok(Detected::Vision(tp));
191 }
192 let tp = NormalLoaderType::from_causal_lm_name(name)?;
193 Ok(Detected::Normal(tp))
194 }
195
196 fn ensure_loader(&self, config: &str) -> Result<()> {
197 let mut guard = self.loader.lock().unwrap();
198 if guard.is_some() {
199 return Ok(());
200 }
201 match self.detect(config)? {
202 Detected::Normal(tp) => {
203 let builder = self
204 .normal_builder
205 .lock()
206 .unwrap()
207 .take()
208 .expect("builder taken");
209 let loader = builder.build(Some(tp)).expect("build normal");
210 *guard = Some(loader);
211 }
212 Detected::Vision(tp) => {
213 let builder = self
214 .vision_builder
215 .lock()
216 .unwrap()
217 .take()
218 .expect("builder taken");
219 let loader = builder.build(Some(tp));
220 *guard = Some(loader);
221 }
222 }
223 Ok(())
224 }
225}
226
227impl Loader for AutoLoader {
228 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
229 fn load_model_from_hf(
230 &self,
231 revision: Option<String>,
232 token_source: TokenSource,
233 dtype: &dyn TryIntoDType,
234 device: &Device,
235 silent: bool,
236 mapper: DeviceMapSetting,
237 in_situ_quant: Option<IsqType>,
238 paged_attn_config: Option<PagedAttentionConfig>,
239 ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
240 let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
241 self.ensure_loader(&config)?;
242 self.loader
243 .lock()
244 .unwrap()
245 .as_ref()
246 .unwrap()
247 .load_model_from_hf(
248 revision,
249 token_source,
250 dtype,
251 device,
252 silent,
253 mapper,
254 in_situ_quant,
255 paged_attn_config,
256 )
257 }
258
259 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
260 fn load_model_from_path(
261 &self,
262 paths: &Box<dyn ModelPaths>,
263 dtype: &dyn TryIntoDType,
264 device: &Device,
265 silent: bool,
266 mapper: DeviceMapSetting,
267 in_situ_quant: Option<IsqType>,
268 paged_attn_config: Option<PagedAttentionConfig>,
269 ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
270 let config = self.read_config_from_path(paths.as_ref())?;
271 self.ensure_loader(&config)?;
272 self.loader
273 .lock()
274 .unwrap()
275 .as_ref()
276 .unwrap()
277 .load_model_from_path(
278 paths,
279 dtype,
280 device,
281 silent,
282 mapper,
283 in_situ_quant,
284 paged_attn_config,
285 )
286 }
287
288 fn get_id(&self) -> String {
289 self.model_id.clone()
290 }
291
292 fn get_kind(&self) -> ModelKind {
293 self.loader
294 .lock()
295 .unwrap()
296 .as_ref()
297 .map(|l| l.get_kind())
298 .unwrap_or(ModelKind::Normal)
299 }
300}