1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod embedding_loaders;
4mod normal_loaders;
5mod vision_loaders;
6pub use auto_device_map::AutoDeviceMapParams;
7use auto_device_map::NonMappedSubModel;
8
9use std::{
10 fmt::{self, Debug},
11 path::PathBuf,
12 str::FromStr,
13 sync::Arc,
14};
15
16use anyhow::Result;
17use as_any::AsAny;
18use candle_core::{DType, Device};
19use mistralrs_quant::{IsqType, QuantizedConfig};
20use serde::Deserialize;
21use tokio::sync::Mutex;
22
23pub use normal_loaders::{
24 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
25 LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata,
26 NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader,
27 Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
28};
29
30pub use vision_loaders::{
31 AutoVisionLoader, Gemma3Loader, Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
32 LLaVANextLoader, MiniCpmOLoader, Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader,
33 Qwen2_5VLLoader, Qwen3VLLoader, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel,
34 VisionModelLoader,
35};
36
37pub use embedding_loaders::{
38 AutoEmbeddingLoader, EmbeddingGemmaLoader, EmbeddingLoaderType, EmbeddingModel,
39 EmbeddingModelLoader, EmbeddingModule, EmbeddingModulePaths, EmbeddingModuleType,
40 Qwen3EmbeddingLoader,
41};
42
43pub use diffusion_loaders::{
44 DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
45 DiffusionModelPathsInner, FluxLoader,
46};
47
48use crate::{
49 matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
50 DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
51};
52
53use super::{paths::AdapterPaths, Pipeline};
54
/// Accessors for the individual files that make up a model on disk.
///
/// Implementors must also be `Debug + Send + Sync` and implement `AsAny`
/// (from the `as_any` crate).
/// NOTE(review): the `AsAny` bound is presumably there so a `dyn ModelPaths`
/// can be downcast to its concrete type — confirm against callers.
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Paths of the model weight files. Returned as a slice, so a model may
    /// be split across several files.
    fn get_weight_filenames(&self) -> &[PathBuf];

    /// Path of the model configuration file.
    fn get_config_filename(&self) -> &PathBuf;

    /// Path of the tokenizer file.
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// Path of the template file, if any.
    /// NOTE(review): presumably the chat-template configuration — confirm.
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Path of the "gen conf" file, if any.
    /// NOTE(review): presumably the generation configuration — confirm.
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;

    /// Path of the preprocessor configuration file, if any.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;

    /// Path of the processor configuration file, if any.
    fn get_processor_config(&self) -> &Option<PathBuf>;

    /// Path of an explicitly provided chat template file, if any.
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;

    /// Adapter paths associated with this model.
    fn get_adapter_paths(&self) -> &AdapterPaths;

    /// Embedding module paths, or `None` when this set of paths has no
    /// modules.
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]>;
}
94
/// Paths of the files belonging to a model, generic over the path type `P`.
///
/// `ModelPaths` is implemented in this file for the `P = PathBuf`
/// instantiation only.
#[derive(Clone, Debug)]
pub struct LocalModelPaths<P: Debug> {
    /// Tokenizer file.
    pub tokenizer_filename: P,
    /// Model configuration file.
    pub config_filename: P,
    /// Template file, if any.
    pub template_filename: Option<P>,
    /// Model weight files.
    pub filenames: Vec<P>,
    /// Adapter paths.
    pub adapter_paths: AdapterPaths,
    /// "Gen conf" file, if any (presumably the generation config — confirm).
    pub gen_conf: Option<P>,
    /// Preprocessor configuration file, if any.
    pub preprocessor_config: Option<P>,
    /// Processor configuration file, if any.
    pub processor_config: Option<P>,
    /// Explicit chat template file, if any; exposed through
    /// `ModelPaths::get_chat_template_explicit`.
    pub chat_template_json_filename: Option<P>,
}
108
109impl<P: Debug> LocalModelPaths<P> {
110 #[allow(clippy::too_many_arguments)]
111 pub fn new(
112 tokenizer_filename: P,
113 config_filename: P,
114 template_filename: P,
115 filenames: Vec<P>,
116 adapter_paths: AdapterPaths,
117 gen_conf: Option<P>,
118 preprocessor_config: Option<P>,
119 processor_config: Option<P>,
120 chat_template_json_filename: Option<P>,
121 ) -> Self {
122 Self {
123 tokenizer_filename,
124 config_filename,
125 template_filename: Some(template_filename),
126 filenames,
127 adapter_paths,
128 gen_conf,
129 preprocessor_config,
130 processor_config,
131 chat_template_json_filename,
132 }
133 }
134}
135
136impl ModelPaths for LocalModelPaths<PathBuf> {
137 fn get_config_filename(&self) -> &PathBuf {
138 &self.config_filename
139 }
140 fn get_tokenizer_filename(&self) -> &PathBuf {
141 &self.tokenizer_filename
142 }
143 fn get_weight_filenames(&self) -> &[PathBuf] {
144 &self.filenames
145 }
146 fn get_template_filename(&self) -> &Option<PathBuf> {
147 &self.template_filename
148 }
149 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
150 self.gen_conf.as_ref()
151 }
152 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
153 &self.preprocessor_config
154 }
155 fn get_processor_config(&self) -> &Option<PathBuf> {
156 &self.processor_config
157 }
158 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
159 &self.chat_template_json_filename
160 }
161 fn get_adapter_paths(&self) -> &AdapterPaths {
162 &self.adapter_paths
163 }
164 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
165 None
166 }
167}
168
/// Paths of the files belonging to an embedding model, generic over the path
/// type `P`.
///
/// `ModelPaths` is implemented in this file for the `P = PathBuf`
/// instantiation only.
#[derive(Clone, Debug)]
pub struct EmbeddingModelPaths<P: Debug> {
    /// Tokenizer file.
    pub tokenizer_filename: P,
    /// Model configuration file.
    pub config_filename: P,
    /// Embedding module paths; exposed through `ModelPaths::get_modules`.
    pub modules: Vec<EmbeddingModulePaths>,
    /// Model weight files.
    pub filenames: Vec<P>,
    /// Adapter paths.
    pub adapter_paths: AdapterPaths,
}
178
179impl<P: Debug> EmbeddingModelPaths<P> {
180 #[allow(clippy::too_many_arguments)]
181 pub fn new(
182 tokenizer_filename: P,
183 config_filename: P,
184 filenames: Vec<P>,
185 adapter_paths: AdapterPaths,
186 modules: Vec<EmbeddingModulePaths>,
187 ) -> Self {
188 Self {
189 tokenizer_filename,
190 config_filename,
191 filenames,
192 adapter_paths,
193 modules,
194 }
195 }
196}
197
198impl ModelPaths for EmbeddingModelPaths<PathBuf> {
199 fn get_config_filename(&self) -> &PathBuf {
200 &self.config_filename
201 }
202 fn get_tokenizer_filename(&self) -> &PathBuf {
203 &self.tokenizer_filename
204 }
205 fn get_weight_filenames(&self) -> &[PathBuf] {
206 &self.filenames
207 }
208 fn get_template_filename(&self) -> &Option<PathBuf> {
209 &None
210 }
211 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
212 None
213 }
214 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
215 &None
216 }
217 fn get_processor_config(&self) -> &Option<PathBuf> {
218 &None
219 }
220 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
221 &None
222 }
223 fn get_adapter_paths(&self) -> &AdapterPaths {
224 &self.adapter_paths
225 }
226 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
227 Some(&self.modules)
228 }
229}
230
/// Describes where a token should be obtained from.
///
/// The textual form, accepted by [`FromStr`] and produced by
/// [`fmt::Display`], is `kind` or `kind:value`.
/// NOTE(review): judging by the default environment variable name this is a
/// Hugging Face Hub token — confirm.
#[derive(Debug, Clone)]
pub enum TokenSource {
    /// `literal:<value>` — carries the string given after the colon.
    Literal(String),
    /// `env:<name>` — carries an environment variable name.
    EnvVar(String),
    /// `path:<path>` — carries a path as a string.
    Path(String),
    /// `cache`.
    CacheToken,
    /// `none`.
    None,
}

impl FromStr for TokenSource {
    type Err = String;

    /// Parses `kind` or `kind:value`, splitting on the first `:` only, so the
    /// value itself may contain further colons.
    ///
    /// * `literal` and `path` require a value (it may be empty).
    /// * `env` without a value falls back to `HUGGING_FACE_HUB_TOKEN`.
    /// * `cache` and `none` ignore anything after the colon.
    ///
    /// # Errors
    ///
    /// Returns a message when a required value is missing or the kind is not
    /// recognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `value` is `None` only when the input contains no colon at all.
        let (kind, value) = match s.split_once(':') {
            Some((kind, rest)) => (kind, Some(rest)),
            None => (s, None),
        };

        match (kind, value) {
            ("literal", Some(v)) => Ok(TokenSource::Literal(v.to_owned())),
            ("literal", None) => Err("Expected a value for 'literal'".to_owned()),
            ("env", v) => Ok(TokenSource::EnvVar(
                v.unwrap_or("HUGGING_FACE_HUB_TOKEN").to_owned(),
            )),
            ("path", Some(v)) => Ok(TokenSource::Path(v.to_owned())),
            ("path", None) => Err("Expected a value for 'path'".to_owned()),
            ("cache", _) => Ok(TokenSource::CacheToken),
            ("none", _) => Ok(TokenSource::None),
            _ => Err("Invalid token source format".to_owned()),
        }
    }
}

impl fmt::Display for TokenSource {
    /// Writes the same `kind` / `kind:value` form that [`FromStr`] accepts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (tag, payload) = match self {
            TokenSource::Literal(v) => ("literal", Some(v)),
            TokenSource::EnvVar(v) => ("env", Some(v)),
            TokenSource::Path(v) => ("path", Some(v)),
            TokenSource::CacheToken => ("cache", None),
            TokenSource::None => ("none", None),
        };

        match payload {
            Some(v) => write!(f, "{tag}:{v}"),
            None => f.write_str(tag),
        }
    }
}
279
/// The kind of model being loaded, including any quantization or adapter and
/// composite (speculative / AnyMoE) arrangements.
///
/// `strum::Display` renders each variant using the `to_string` template on
/// that variant. `derive_more::From` generates `From` conversions from the
/// variants' field types; see the `derive_more` documentation for the exact
/// set produced.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// A plain model: not quantized and without adapters. This is the
    /// `Default` value.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// A quantized model without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// A non-quantized model with an adapter.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// A quantized model with an adapter.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// A speculative pairing of a target model and a draft model, each of
    /// which is itself a `ModelKind`.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// An AnyMoE model wrapping a target `ModelKind`.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
308
/// Quantization formats a model may be quantized from.
///
/// `Display` output is the kebab-case variant name (per the `serialize_all`
/// attribute), and `strum::EnumIs` generates `is_*` predicates for each
/// variant.
// NOTE: the variants below are annotated with plain `//` comments on purpose.
// `strum::EnumMessage::get_documentation` exposes variant *doc* comments at
// runtime, and `PrettyName::pretty_name` returns them, so adding `///` to a
// variant would change program output.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum QuantizationKind {
    // GGML format.
    Ggml,
    // GGUF format.
    Gguf,
    // GPTQ format.
    Gptq,
}
319
/// Adapter types that can be applied to a model.
///
/// `Display` output is the kebab-case variant name (per the `serialize_all`
/// attribute), and `strum::EnumIs` generates `is_*` predicates for each
/// variant.
// NOTE: the variants below are annotated with plain `//` comments on purpose.
// `strum::EnumMessage::get_documentation` exposes variant *doc* comments at
// runtime, and `PrettyName::pretty_name` returns them, so adding `///` to a
// variant would change program output.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum AdapterKind {
    // LoRA adapter.
    Lora,
    // X-LoRA adapter.
    XLora,
}
328
329pub trait PrettyName: strum::EnumMessage + ToString {
331 fn pretty_name(&self) -> String {
332 match self.get_documentation() {
333 Some(s) => s.to_string(),
334 None => self.to_string(),
337 }
338 }
339}
340
341impl PrettyName for AdapterKind {}
342impl PrettyName for QuantizationKind {}
343
344impl ModelKind {
345 pub fn is_quantized(&self) -> bool {
347 self.quantized_kind().iter().any(|q| q.is_some())
348 }
349
350 pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
351 self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
352 }
353
354 pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
355 use ModelKind::*;
356
357 match self {
358 Normal | Adapter { .. } => vec![None],
359 GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
360 Speculative { target, draft } => {
361 let t = *target.clone();
362 let d = *draft.clone();
363
364 [t.quantized_kind(), d.quantized_kind()].concat()
365 }
366 AnyMoe { target } => target.quantized_kind(),
367 }
368 }
369
370 pub fn is_adapted(&self) -> bool {
372 self.adapted_kind().iter().any(|a| a.is_some())
373 }
374
375 pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
376 self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
377 }
378
379 pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
380 use ModelKind::*;
381
382 match self {
383 Normal | GgufQuantized { .. } => vec![None],
384 Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
385 Speculative { target, draft } => {
386 let t = *target.clone();
387 let d = *draft.clone();
388
389 [t.adapted_kind(), d.adapted_kind()].concat()
390 }
391 AnyMoe { target } => target.adapted_kind(),
392 }
393 }
394}
395
/// Minimal deserialization target that extracts only the optional
/// `quantization_config` entry from a serialized configuration.
#[derive(Deserialize)]
pub struct QuantizationConfigShim {
    /// The quantization configuration, if the input contained one.
    quantization_config: Option<QuantizedConfig>,
}
400
401impl QuantizationConfigShim {
402 pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
403 let QuantizationConfigShim {
404 quantization_config,
405 } = serde_json::from_str(config)?;
406
407 if let Some(quantization_config) = quantization_config {
408 Ok(quantization_config.pack_factor(dtype))
409 } else {
410 Ok(1)
411 }
412 }
413}
414
/// Size and layout queries a model loader answers so that its layers can be
/// assigned to devices.
///
/// Every method taking `config` receives it as a string.
/// NOTE(review): presumably the raw contents of the model's configuration
/// file — confirm against callers.
pub trait DeviceMappedModelLoader {
    /// Maximum activation size, in elements, for the non-mapped part of the
    /// model.
    /// NOTE(review): "non-mapped" presumably means components that are not
    /// distributed across devices layer by layer — confirm.
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Maximum activation size, in elements, for the mapped part of the
    /// model.
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Size in bytes of the non-mapped part of the model for the given
    /// `dtype`, weight pack factor and optional Matformer slice
    /// configuration.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize>;
    /// Sizes in bytes of the model's layers for the given `dtype`, weight
    /// pack factor and optional Matformer slice configuration.
    /// NOTE(review): presumably one entry per layer, in layer order —
    /// confirm.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>>;
    /// Sub-models that are not device mapped. The default implementation
    /// returns `None`.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    /// Number of layers described by `config`.
    fn num_layers(&self, config: &str) -> Result<usize>;
    /// Builds a boxed `ModelConfigLike` view of `config`.
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    /// Computes the device mapping for this model.
    ///
    /// The default implementation forwards `self` and every argument,
    /// unchanged, to `auto_device_map::get_device_layers`. The `Self: Sized`
    /// bound means this method cannot be called on a trait object.
    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata>
    where
        Self: Sized,
    {
        auto_device_map::get_device_layers(
            self,
            config,
            num_layers,
            layer_sizes_in_bytes,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            paged_attn_config,
        )
    }
}
481
/// Loads a model and produces a ready-to-use [`Pipeline`].
///
/// Both loading methods return the pipeline wrapped in `Arc<Mutex<..>>`
/// (the `tokio` mutex imported by this module).
pub trait Loader: Send + Sync {
    /// Loads the model identified by this loader, given an optional
    /// `revision` and a [`TokenSource`] for authentication.
    ///
    /// NOTE(review): presumably fetches the files from the Hugging Face Hub;
    /// `silent` presumably suppresses progress output and `in_situ_quant`
    /// presumably requests in-situ quantization — confirm against
    /// implementors.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Loads the model from an already resolved set of [`ModelPaths`].
    ///
    /// The remaining parameters mirror those of `load_model_from_hf`.
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Identifier of this loader's model, as a string.
    fn get_id(&self) -> String;
    /// The [`ModelKind`] this loader produces.
    fn get_kind(&self) -> ModelKind;
}