1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod embedding_loaders;
4mod normal_loaders;
5mod vision_loaders;
6pub use auto_device_map::AutoDeviceMapParams;
7use auto_device_map::NonMappedSubModel;
8
9use std::{
10 fmt::{self, Debug},
11 path::PathBuf,
12 str::FromStr,
13 sync::Arc,
14};
15
16use anyhow::Result;
17use as_any::AsAny;
18use candle_core::{DType, Device};
19use mistralrs_quant::{IsqType, QuantizedConfig};
20use serde::Deserialize;
21use tokio::sync::Mutex;
22
23pub use normal_loaders::{
24 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
25 GptOssLoader, GraniteMoeHybridLoader, LlamaLoader, MistralLoader, MixtralLoader,
26 NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader, Phi2Loader,
27 Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader,
28 Starcoder2Loader,
29};
30
31pub use vision_loaders::{
32 AutoVisionLoader, Gemma3Loader, Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
33 LLaVANextLoader, MiniCpmOLoader, Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader,
34 Qwen2_5VLLoader, Qwen3VLLoader, Qwen3VLMoELoader, VLlama4Loader, VLlamaLoader,
35 VisionLoaderType, VisionModel, VisionModelLoader,
36};
37
38pub use embedding_loaders::{
39 AutoEmbeddingLoader, EmbeddingGemmaLoader, EmbeddingLoaderType, EmbeddingModel,
40 EmbeddingModelLoader, EmbeddingModule, EmbeddingModulePaths, EmbeddingModuleType,
41 Qwen3EmbeddingLoader,
42};
43
44pub use diffusion_loaders::{
45 DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
46 DiffusionModelPathsInner, FluxLoader,
47};
48
49use crate::{
50 matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
51 DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
52};
53
54use super::{paths::AdapterPaths, Pipeline};
55
56pub trait ModelPaths: AsAny + Debug + Send + Sync {
59 fn get_weight_filenames(&self) -> &[PathBuf];
61
62 fn get_config_filename(&self) -> &PathBuf;
66
67 fn get_tokenizer_filename(&self) -> &PathBuf;
71
72 fn get_template_filename(&self) -> &Option<PathBuf>;
76
77 fn get_gen_conf_filename(&self) -> Option<&PathBuf>;
79
80 fn get_preprocessor_config(&self) -> &Option<PathBuf>;
82
83 fn get_processor_config(&self) -> &Option<PathBuf>;
85
86 fn get_chat_template_explicit(&self) -> &Option<PathBuf>;
88
89 fn get_adapter_paths(&self) -> &AdapterPaths;
91
92 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]>;
94}
95
/// File locations for a model, generic over the path type `P`.
///
/// `ModelPaths` is implemented for `LocalModelPaths<PathBuf>`; that impl
/// reports no embedding modules.
#[derive(Clone, Debug)]
pub struct LocalModelPaths<P: Debug> {
    /// Tokenizer file.
    pub tokenizer_filename: P,
    /// Model configuration file.
    pub config_filename: P,
    /// Template file; `LocalModelPaths::new` always stores `Some` here.
    pub template_filename: Option<P>,
    /// Weight files.
    pub filenames: Vec<P>,
    /// Adapter locations.
    pub adapter_paths: AdapterPaths,
    /// Generation-config file, if any.
    pub gen_conf: Option<P>,
    /// Preprocessor config file, if any.
    pub preprocessor_config: Option<P>,
    /// Processor config file, if any.
    pub processor_config: Option<P>,
    /// Explicit chat-template JSON file, if any (reported by
    /// `ModelPaths::get_chat_template_explicit`).
    pub chat_template_json_filename: Option<P>,
}
109
110impl<P: Debug> LocalModelPaths<P> {
111 #[allow(clippy::too_many_arguments)]
112 pub fn new(
113 tokenizer_filename: P,
114 config_filename: P,
115 template_filename: P,
116 filenames: Vec<P>,
117 adapter_paths: AdapterPaths,
118 gen_conf: Option<P>,
119 preprocessor_config: Option<P>,
120 processor_config: Option<P>,
121 chat_template_json_filename: Option<P>,
122 ) -> Self {
123 Self {
124 tokenizer_filename,
125 config_filename,
126 template_filename: Some(template_filename),
127 filenames,
128 adapter_paths,
129 gen_conf,
130 preprocessor_config,
131 processor_config,
132 chat_template_json_filename,
133 }
134 }
135}
136
137impl ModelPaths for LocalModelPaths<PathBuf> {
138 fn get_config_filename(&self) -> &PathBuf {
139 &self.config_filename
140 }
141 fn get_tokenizer_filename(&self) -> &PathBuf {
142 &self.tokenizer_filename
143 }
144 fn get_weight_filenames(&self) -> &[PathBuf] {
145 &self.filenames
146 }
147 fn get_template_filename(&self) -> &Option<PathBuf> {
148 &self.template_filename
149 }
150 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
151 self.gen_conf.as_ref()
152 }
153 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
154 &self.preprocessor_config
155 }
156 fn get_processor_config(&self) -> &Option<PathBuf> {
157 &self.processor_config
158 }
159 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
160 &self.chat_template_json_filename
161 }
162 fn get_adapter_paths(&self) -> &AdapterPaths {
163 &self.adapter_paths
164 }
165 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
166 None
167 }
168}
169
/// File locations for an embedding model, generic over the path type `P`.
///
/// Unlike `LocalModelPaths`, this carries a list of embedding-module paths and
/// no template/generation/processor files; the `ModelPaths` impl for
/// `EmbeddingModelPaths<PathBuf>` reports those as absent.
#[derive(Clone, Debug)]
pub struct EmbeddingModelPaths<P: Debug> {
    /// Tokenizer file.
    pub tokenizer_filename: P,
    /// Model configuration file.
    pub config_filename: P,
    /// Paths of the embedding modules.
    pub modules: Vec<EmbeddingModulePaths>,
    /// Weight files.
    pub filenames: Vec<P>,
    /// Adapter locations.
    pub adapter_paths: AdapterPaths,
}
179
180impl<P: Debug> EmbeddingModelPaths<P> {
181 #[allow(clippy::too_many_arguments)]
182 pub fn new(
183 tokenizer_filename: P,
184 config_filename: P,
185 filenames: Vec<P>,
186 adapter_paths: AdapterPaths,
187 modules: Vec<EmbeddingModulePaths>,
188 ) -> Self {
189 Self {
190 tokenizer_filename,
191 config_filename,
192 filenames,
193 adapter_paths,
194 modules,
195 }
196 }
197}
198
199impl ModelPaths for EmbeddingModelPaths<PathBuf> {
200 fn get_config_filename(&self) -> &PathBuf {
201 &self.config_filename
202 }
203 fn get_tokenizer_filename(&self) -> &PathBuf {
204 &self.tokenizer_filename
205 }
206 fn get_weight_filenames(&self) -> &[PathBuf] {
207 &self.filenames
208 }
209 fn get_template_filename(&self) -> &Option<PathBuf> {
210 &None
211 }
212 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
213 None
214 }
215 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
216 &None
217 }
218 fn get_processor_config(&self) -> &Option<PathBuf> {
219 &None
220 }
221 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
222 &None
223 }
224 fn get_adapter_paths(&self) -> &AdapterPaths {
225 &self.adapter_paths
226 }
227 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
228 Some(&self.modules)
229 }
230}
231
/// Where the access token handed to `Loader::load_model_from_hf` comes from.
///
/// Textual form (see the `FromStr` / `Display` impls): `none`, `cache`,
/// `literal:<token>`, `env[:<VAR>]`, or `path:<file>`. This enum only records
/// the choice; turning it into an actual token happens elsewhere.
#[derive(Debug, Clone)]
pub enum TokenSource {
    /// Do not use a token (`none`).
    None,
    /// Use a previously cached token (`cache`).
    CacheToken,
    /// The token given inline (`literal:<token>`).
    Literal(String),
    /// Name of an environment variable expected to hold the token (`env[:<VAR>]`).
    EnvVar(String),
    /// Path of a file expected to contain the token (`path:<file>`).
    Path(String),
}
241
242impl FromStr for TokenSource {
243 type Err = String;
244
245 fn from_str(s: &str) -> Result<Self, Self::Err> {
246 let parts: Vec<&str> = s.splitn(2, ':').collect();
247 match parts[0] {
248 "literal" => parts
249 .get(1)
250 .map(|&value| TokenSource::Literal(value.to_string()))
251 .ok_or_else(|| "Expected a value for 'literal'".to_string()),
252 "env" => Ok(TokenSource::EnvVar(
253 parts
254 .get(1)
255 .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
256 .to_string(),
257 )),
258 "path" => parts
259 .get(1)
260 .map(|&value| TokenSource::Path(value.to_string()))
261 .ok_or_else(|| "Expected a value for 'path'".to_string()),
262 "cache" => Ok(TokenSource::CacheToken),
263 "none" => Ok(TokenSource::None),
264 _ => Err("Invalid token source format".to_string()),
265 }
266 }
267}
268
269impl fmt::Display for TokenSource {
270 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271 match self {
272 TokenSource::Literal(value) => write!(f, "literal:{value}"),
273 TokenSource::EnvVar(value) => write!(f, "env:{value}"),
274 TokenSource::Path(value) => write!(f, "path:{value}"),
275 TokenSource::CacheToken => write!(f, "cache"),
276 TokenSource::None => write!(f, "none"),
277 }
278 }
279}
280
/// What kind of model a `Loader` produces.
///
/// Leaf kinds describe a single model (plain, GGUF-quantized, with an adapter,
/// or both). `Speculative` and `AnyMoe` wrap other `ModelKind`s. The strum
/// `Display` impl renders each variant with the `to_string` template attached
/// to it. `Normal` is the `Default`.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// Plain model with no quantization or adapters.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// Model quantized from the `quant` format, without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// Non-quantized model with an `adapter`.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// Model quantized from `quant` that also uses an `adapter`.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// Speculative setup pairing a `target` model with a `draft` model.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// AnyMoE wrapper around a `target` model.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
309
310#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
311#[strum(serialize_all = "kebab-case")]
312pub enum QuantizationKind {
313 Ggml,
315 Gguf,
317 Gptq,
319}
320
321#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
322#[strum(serialize_all = "kebab-case")]
323pub enum AdapterKind {
324 Lora,
326 XLora,
328}
329
330pub trait PrettyName: strum::EnumMessage + ToString {
332 fn pretty_name(&self) -> String {
333 match self.get_documentation() {
334 Some(s) => s.to_string(),
335 None => self.to_string(),
338 }
339 }
340}
341
342impl PrettyName for AdapterKind {}
343impl PrettyName for QuantizationKind {}
344
345impl ModelKind {
346 pub fn is_quantized(&self) -> bool {
348 self.quantized_kind().iter().any(|q| q.is_some())
349 }
350
351 pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
352 self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
353 }
354
355 pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
356 use ModelKind::*;
357
358 match self {
359 Normal | Adapter { .. } => vec![None],
360 GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
361 Speculative { target, draft } => {
362 let t = *target.clone();
363 let d = *draft.clone();
364
365 [t.quantized_kind(), d.quantized_kind()].concat()
366 }
367 AnyMoe { target } => target.quantized_kind(),
368 }
369 }
370
371 pub fn is_adapted(&self) -> bool {
373 self.adapted_kind().iter().any(|a| a.is_some())
374 }
375
376 pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
377 self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
378 }
379
380 pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
381 use ModelKind::*;
382
383 match self {
384 Normal | GgufQuantized { .. } => vec![None],
385 Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
386 Speculative { target, draft } => {
387 let t = *target.clone();
388 let d = *draft.clone();
389
390 [t.adapted_kind(), d.adapted_kind()].concat()
391 }
392 AnyMoe { target } => target.adapted_kind(),
393 }
394 }
395}
396
397#[derive(Deserialize)]
398pub struct QuantizationConfigShim {
399 quantization_config: Option<QuantizedConfig>,
400}
401
402impl QuantizationConfigShim {
403 pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
404 let QuantizationConfigShim {
405 quantization_config,
406 } = serde_json::from_str(config)?;
407
408 if let Some(quantization_config) = quantization_config {
409 Ok(quantization_config.pack_factor(dtype))
410 } else {
411 Ok(1)
412 }
413 }
414}
415
/// Size and layer information a loader exposes so that model layers can be
/// assigned to devices automatically (see `auto_device_map`).
///
/// NOTE(review): `config` is a `&str` throughout; presumably the model's
/// config file contents (JSON, as parsed by `QuantizationConfigShim`). Confirm
/// against implementors.
pub trait DeviceMappedModelLoader {
    /// Activation size, in elements, for the part of the model that is not
    /// device-mapped, under the given auto-device-map `params` (meaning per the
    /// method name; exact accounting is up to each implementor).
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Activation size, in elements, for the device-mapped layers under `params`
    /// (meaning per the method name; exact accounting is up to each implementor).
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Size in bytes of the weights that are not device-mapped.
    ///
    /// `weight_pack_factor` is presumably the value produced by
    /// `QuantizationConfigShim::get_quant_config_pack_factor` (TODO confirm).
    /// `matformer_config` optionally selects a MatFormer slice.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize>;
    /// Size in bytes of each device-mappable layer, one entry per layer.
    ///
    /// Takes the same `dtype` / `weight_pack_factor` / `matformer_config`
    /// inputs as `non_mapped_size_in_bytes`.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>>;
    /// Extra sub-models kept off the device map; defaults to `None` (no sub-models).
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    /// Number of device-mappable layers described by `config`.
    fn num_layers(&self, config: &str) -> Result<usize>;
    /// Parses `config` into a type-erased `ModelConfigLike`.
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    /// Computes the device-map metadata for this model.
    ///
    /// The default forwards `self` and every argument, in order, to
    /// `auto_device_map::get_device_layers`.
    // Many independent sizing inputs; bundling them would only move the problem.
    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata>
    where
        // Excludes this provided method from trait objects, so the trait stays
        // object-safe. NOTE(review): presumably required because the helper takes
        // a sized `self`; confirm against `auto_device_map::get_device_layers`.
        Self: Sized,
    {
        auto_device_map::get_device_layers(
            self,
            config,
            num_layers,
            layer_sizes_in_bytes,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            paged_attn_config,
        )
    }
}
482
483pub trait Loader: Send + Sync {
504 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
508 fn load_model_from_hf(
509 &self,
510 revision: Option<String>,
511 token_source: TokenSource,
512 dtype: &dyn TryIntoDType,
513 device: &Device,
514 silent: bool,
515 mapper: DeviceMapSetting,
516 in_situ_quant: Option<IsqType>,
517 paged_attn_config: Option<PagedAttentionConfig>,
518 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
519
520 #[allow(
523 clippy::type_complexity,
524 clippy::too_many_arguments,
525 clippy::borrowed_box
526 )]
527 fn load_model_from_path(
528 &self,
529 paths: &Box<dyn ModelPaths>,
530 dtype: &dyn TryIntoDType,
531 device: &Device,
532 silent: bool,
533 mapper: DeviceMapSetting,
534 in_situ_quant: Option<IsqType>,
535 paged_attn_config: Option<PagedAttentionConfig>,
536 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
537
538 fn get_id(&self) -> String;
539 fn get_kind(&self) -> ModelKind;
540}