1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod normal_loaders;
4mod vision_loaders;
5pub use auto_device_map::AutoDeviceMapParams;
6use auto_device_map::NonMappedSubModel;
7
8use std::{
9 fmt::{self, Debug},
10 path::PathBuf,
11 str::FromStr,
12 sync::Arc,
13};
14
15use anyhow::Result;
16use as_any::AsAny;
17use candle_core::{DType, Device};
18use mistralrs_quant::{IsqType, QuantizedConfig};
19use serde::Deserialize;
20use tokio::sync::Mutex;
21
22pub use normal_loaders::{
23 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
24 LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata,
25 NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader,
26 Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
27};
28
29pub use vision_loaders::{
30 AutoVisionLoader, Gemma3Loader, Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
31 LLaVANextLoader, MiniCpmOLoader, Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader,
32 Qwen2_5VLLoader, Qwen3VLLoader, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel,
33 VisionModelLoader,
34};
35
36pub use diffusion_loaders::{
37 DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
38 DiffusionModelPathsInner, FluxLoader,
39};
40
41use crate::{
42 matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
43 DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
44};
45
46use super::{paths::AdapterPaths, Pipeline};
47
48pub trait ModelPaths: AsAny + Debug + Send + Sync {
51 fn get_weight_filenames(&self) -> &[PathBuf];
53
54 fn get_config_filename(&self) -> &PathBuf;
58
59 fn get_tokenizer_filename(&self) -> &PathBuf;
63
64 fn get_template_filename(&self) -> &Option<PathBuf>;
68
69 fn get_gen_conf_filename(&self) -> Option<&PathBuf>;
71
72 fn get_preprocessor_config(&self) -> &Option<PathBuf>;
74
75 fn get_processor_config(&self) -> &Option<PathBuf>;
77
78 fn get_chat_template_explicit(&self) -> &Option<PathBuf>;
80
81 fn get_adapter_paths(&self) -> &AdapterPaths;
83}
84
/// An already-resolved set of model artifact paths.
///
/// Generic over the path type `P`. The `ModelPaths` implementation in this file
/// targets `LocalModelPaths<PathBuf>`, and each of its getters borrows the matching field.
#[derive(Clone, Debug)]
pub struct LocalModelPaths<P: Debug> {
    /// Tokenizer file.
    pub tokenizer_filename: P,
    /// Model configuration file.
    pub config_filename: P,
    /// Template file; `LocalModelPaths::new` always stores `Some(..)` here.
    pub template_filename: Option<P>,
    /// Model weight files (returned by `get_weight_filenames`).
    pub filenames: Vec<P>,
    /// Adapter paths for this model.
    pub adapter_paths: AdapterPaths,
    /// Optional generation-config file.
    pub gen_conf: Option<P>,
    /// Optional preprocessor config file.
    pub preprocessor_config: Option<P>,
    /// Optional processor config file.
    pub processor_config: Option<P>,
    /// Optional explicit chat-template file (returned by `get_chat_template_explicit`).
    pub chat_template_json_filename: Option<P>,
}
98
99impl<P: Debug> LocalModelPaths<P> {
100 #[allow(clippy::too_many_arguments)]
101 pub fn new(
102 tokenizer_filename: P,
103 config_filename: P,
104 template_filename: P,
105 filenames: Vec<P>,
106 adapter_paths: AdapterPaths,
107 gen_conf: Option<P>,
108 preprocessor_config: Option<P>,
109 processor_config: Option<P>,
110 chat_template_json_filename: Option<P>,
111 ) -> Self {
112 Self {
113 tokenizer_filename,
114 config_filename,
115 template_filename: Some(template_filename),
116 filenames,
117 adapter_paths,
118 gen_conf,
119 preprocessor_config,
120 processor_config,
121 chat_template_json_filename,
122 }
123 }
124}
125
126impl ModelPaths for LocalModelPaths<PathBuf> {
127 fn get_config_filename(&self) -> &PathBuf {
128 &self.config_filename
129 }
130 fn get_tokenizer_filename(&self) -> &PathBuf {
131 &self.tokenizer_filename
132 }
133 fn get_weight_filenames(&self) -> &[PathBuf] {
134 &self.filenames
135 }
136 fn get_template_filename(&self) -> &Option<PathBuf> {
137 &self.template_filename
138 }
139 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
140 self.gen_conf.as_ref()
141 }
142 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
143 &self.preprocessor_config
144 }
145 fn get_processor_config(&self) -> &Option<PathBuf> {
146 &self.processor_config
147 }
148 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
149 &self.chat_template_json_filename
150 }
151 fn get_adapter_paths(&self) -> &AdapterPaths {
152 &self.adapter_paths
153 }
154}
155
/// Where the access token for hub downloads comes from.
///
/// Textual forms (accepted by `FromStr`, produced by `Display`):
/// `literal:<token>`, `env:<VAR>` (a bare `env` means `env:HUGGING_FACE_HUB_TOKEN`),
/// `path:<file>`, `cache`, and `none`.
#[derive(Debug, Clone)]
pub enum TokenSource {
    /// The token value itself.
    Literal(String),
    /// Name of the environment variable holding the token.
    EnvVar(String),
    /// Path of a file holding the token.
    Path(String),
    /// Use a cached token (how it is located is up to the consumer).
    CacheToken,
    /// Do not use a token.
    None,
}

impl FromStr for TokenSource {
    type Err = String;

    /// Parses `<kind>` or `<kind>:<value>`.
    ///
    /// Only the first `:` separates kind from value, so a value may itself contain colons.
    ///
    /// # Errors
    /// Returns a message when `literal` or `path` has no value, or the kind is unknown.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // A missing separator means "kind only, no value".
        let (kind, value) = match s.split_once(':') {
            Some((kind, value)) => (kind, Some(value)),
            None => (s, None),
        };

        match kind {
            "literal" => value
                .map(|v| Self::Literal(v.to_owned()))
                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
            "env" => Ok(Self::EnvVar(
                value.unwrap_or("HUGGING_FACE_HUB_TOKEN").to_owned(),
            )),
            "path" => value
                .map(|v| Self::Path(v.to_owned()))
                .ok_or_else(|| "Expected a value for 'path'".to_string()),
            "cache" => Ok(Self::CacheToken),
            "none" => Ok(Self::None),
            _ => Err("Invalid token source format".to_string()),
        }
    }
}

impl fmt::Display for TokenSource {
    /// Renders the same textual form that `FromStr` accepts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Literal(value) => write!(f, "literal:{value}"),
            Self::EnvVar(value) => write!(f, "env:{value}"),
            Self::Path(value) => write!(f, "path:{value}"),
            Self::CacheToken => f.write_str("cache"),
            Self::None => f.write_str("none"),
        }
    }
}
204
/// How a loaded model is composed: plain, quantized, adapter-equipped, or built
/// from other kinds (a `target`/`draft` pair, or an AnyMoE wrapper around a `target`).
///
/// `Display` text comes from each variant's `#[strum(to_string = ...)]` attribute,
/// which interpolates the variant's named fields. `Normal` is the `Default`.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// A model with no adapters and no quantization.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// A quantized model (format given by `quant`) without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// An unquantized model carrying an adapter of kind `adapter`.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// A quantized model that also carries an adapter.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// A `target` model paired with a `draft` model (presumably for speculative decoding).
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// An AnyMoE wrapper around the `target` model kind.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
233
234#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
235#[strum(serialize_all = "kebab-case")]
236pub enum QuantizationKind {
237 Ggml,
239 Gguf,
241 Gptq,
243}
244
245#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
246#[strum(serialize_all = "kebab-case")]
247pub enum AdapterKind {
248 Lora,
250 XLora,
252}
253
254pub trait PrettyName: strum::EnumMessage + ToString {
256 fn pretty_name(&self) -> String {
257 match self.get_documentation() {
258 Some(s) => s.to_string(),
259 None => self.to_string(),
262 }
263 }
264}
265
266impl PrettyName for AdapterKind {}
267impl PrettyName for QuantizationKind {}
268
269impl ModelKind {
270 pub fn is_quantized(&self) -> bool {
272 self.quantized_kind().iter().any(|q| q.is_some())
273 }
274
275 pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
276 self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
277 }
278
279 pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
280 use ModelKind::*;
281
282 match self {
283 Normal | Adapter { .. } => vec![None],
284 GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
285 Speculative { target, draft } => {
286 let t = *target.clone();
287 let d = *draft.clone();
288
289 [t.quantized_kind(), d.quantized_kind()].concat()
290 }
291 AnyMoe { target } => target.quantized_kind(),
292 }
293 }
294
295 pub fn is_adapted(&self) -> bool {
297 self.adapted_kind().iter().any(|a| a.is_some())
298 }
299
300 pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
301 self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
302 }
303
304 pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
305 use ModelKind::*;
306
307 match self {
308 Normal | GgufQuantized { .. } => vec![None],
309 Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
310 Speculative { target, draft } => {
311 let t = *target.clone();
312 let d = *draft.clone();
313
314 [t.adapted_kind(), d.adapted_kind()].concat()
315 }
316 AnyMoe { target } => target.adapted_kind(),
317 }
318 }
319}
320
321#[derive(Deserialize)]
322pub struct QuantizationConfigShim {
323 quantization_config: Option<QuantizedConfig>,
324}
325
326impl QuantizationConfigShim {
327 pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
328 let QuantizationConfigShim {
329 quantization_config,
330 } = serde_json::from_str(config)?;
331
332 if let Some(quantization_config) = quantization_config {
333 Ok(quantization_config.pack_factor(dtype))
334 } else {
335 Ok(1)
336 }
337 }
338}
339
340pub trait DeviceMappedModelLoader {
341 fn non_mapped_max_act_size_elems(
344 &self,
345 config: &str,
346 params: &AutoDeviceMapParams,
347 ) -> Result<usize>;
348 fn mapped_max_act_size_elems(
350 &self,
351 config: &str,
352 params: &AutoDeviceMapParams,
353 ) -> Result<usize>;
354 fn non_mapped_size_in_bytes(
356 &self,
357 config: &str,
358 dtype: DType,
359 weight_pack_factor: usize,
360 matformer_config: Option<&MatformerSliceConfig>,
361 ) -> Result<usize>;
362 fn layer_sizes_in_bytes(
364 &self,
365 config: &str,
366 dtype: DType,
367 weight_pack_factor: usize,
368 matformer_config: Option<&MatformerSliceConfig>,
369 ) -> Result<Vec<usize>>;
370 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
371 None
372 }
373 fn num_layers(&self, config: &str) -> Result<usize>;
374 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;
375
376 #[allow(clippy::too_many_arguments)]
377 fn get_device_layers(
378 &self,
379 config: &str,
380 num_layers: usize,
381 layer_sizes_in_bytes: Vec<usize>,
382 non_mapped_size_in_bytes: usize,
383 total_model_size_in_bytes: usize,
384 devices: &[Device],
385 dtype: DType,
386 params: &AutoDeviceMapParams,
387 paged_attn_config: Option<&PagedAttentionConfig>,
388 ) -> Result<DeviceMapMetadata>
389 where
390 Self: Sized,
391 {
392 auto_device_map::get_device_layers(
393 self,
394 config,
395 num_layers,
396 layer_sizes_in_bytes,
397 non_mapped_size_in_bytes,
398 total_model_size_in_bytes,
399 devices,
400 dtype,
401 params,
402 paged_attn_config,
403 )
404 }
405}
406
407pub trait Loader: Send + Sync {
428 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
432 fn load_model_from_hf(
433 &self,
434 revision: Option<String>,
435 token_source: TokenSource,
436 dtype: &dyn TryIntoDType,
437 device: &Device,
438 silent: bool,
439 mapper: DeviceMapSetting,
440 in_situ_quant: Option<IsqType>,
441 paged_attn_config: Option<PagedAttentionConfig>,
442 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
443
444 #[allow(
447 clippy::type_complexity,
448 clippy::too_many_arguments,
449 clippy::borrowed_box
450 )]
451 fn load_model_from_path(
452 &self,
453 paths: &Box<dyn ModelPaths>,
454 dtype: &dyn TryIntoDType,
455 device: &Device,
456 silent: bool,
457 mapper: DeviceMapSetting,
458 in_situ_quant: Option<IsqType>,
459 paged_attn_config: Option<PagedAttentionConfig>,
460 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
461
462 fn get_id(&self) -> String;
463 fn get_kind(&self) -> ModelKind;
464}