mistralrs_core/pipeline/loaders/mod.rs

pub(crate) mod auto_device_map;
mod diffusion_loaders;
mod embedding_loaders;
mod normal_loaders;
mod vision_loaders;
pub use auto_device_map::AutoDeviceMapParams;
use auto_device_map::NonMappedSubModel;

use std::{
    fmt::{self, Debug},
    path::PathBuf,
    str::FromStr,
    sync::Arc,
};

use anyhow::Result;
use as_any::AsAny;
use candle_core::{DType, Device};
use mistralrs_quant::{IsqType, QuantizedConfig};
use serde::Deserialize;
use tokio::sync::Mutex;

pub use normal_loaders::{
    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
    GptOssLoader, GraniteMoeHybridLoader, LlamaLoader, MistralLoader, MixtralLoader,
    NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader, Phi2Loader,
    Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader,
    Starcoder2Loader,
};

pub use vision_loaders::{
    AutoVisionLoader, Gemma3Loader, Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
    LLaVANextLoader, MiniCpmOLoader, Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader,
    Qwen2_5VLLoader, Qwen3VLLoader, Qwen3VLMoELoader, VLlama4Loader, VLlamaLoader,
    VisionLoaderType, VisionModel, VisionModelLoader,
};

pub use embedding_loaders::{
    AutoEmbeddingLoader, EmbeddingGemmaLoader, EmbeddingLoaderType, EmbeddingModel,
    EmbeddingModelLoader, EmbeddingModule, EmbeddingModulePaths, EmbeddingModuleType,
    Qwen3EmbeddingLoader,
};

pub use diffusion_loaders::{
    DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
    DiffusionModelPathsInner, FluxLoader,
};

use crate::{
    matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
    DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
};

use super::{paths::AdapterPaths, Pipeline};

/// `ModelPaths` abstracts the mechanism to get all necessary files for running a model. For
/// example, `LocalModelPaths` implements `ModelPaths` when all files are in the local file system.
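///
/// A minimal usage sketch (illustrative only; the fence is `ignore`d since concrete paths are
/// normally resolved by a [`Loader`] implementation rather than by hand):
///
/// ```ignore
/// fn describe(paths: &dyn ModelPaths) {
///     println!("config: {:?}", paths.get_config_filename());
///     println!("tokenizer: {:?}", paths.get_tokenizer_filename());
///     println!("weights: {:?}", paths.get_weight_filenames());
/// }
/// ```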
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Model weights files (multiple files supported).
    fn get_weight_filenames(&self) -> &[PathBuf];

    /// Retrieve the [`PretrainedConfig`] file.
    ///
    /// [`PretrainedConfig`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/configuration#transformers.PretrainedConfig
    fn get_config_filename(&self) -> &PathBuf;

    /// A serialised HuggingFace [`tokenizers.Tokenizer`] object.
    ///
    /// [`tokenizers.Tokenizer`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// File where the content is expected to deserialize to [`ChatTemplate`].
    ///
    /// [`ChatTemplate`]: crate::ChatTemplate
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Filepath for the generation configuration (e.g. `generation_config.json`).
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;

    /// Get the preprocessor config (for vision models). This is used to preprocess images.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;

    /// Get the processor config (for vision models). This is primarily used for the chat template.
    fn get_processor_config(&self) -> &Option<PathBuf>;

    /// Get the explicit chat template.
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;

    /// Get adapter paths.
    fn get_adapter_paths(&self) -> &AdapterPaths;

    /// Get the embedding model's sentence-transformers-compatible `modules.json` modules, if any.
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]>;
}

#[derive(Clone, Debug)]
/// All local paths and metadata necessary to load a model.
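///
/// A construction sketch (illustrative only; real paths are normally resolved by a [`Loader`],
/// and adapter resolution is elided here):
///
/// ```ignore
/// use std::path::PathBuf;
///
/// let adapter_paths: AdapterPaths = todo!(); // adapter resolution elided
/// let paths = LocalModelPaths::new(
///     PathBuf::from("tokenizer.json"),
///     PathBuf::from("config.json"),
///     PathBuf::from("tokenizer_config.json"),
///     vec![PathBuf::from("model.safetensors")],
///     adapter_paths,
///     None, // gen_conf
///     None, // preprocessor_config
///     None, // processor_config
///     None, // chat_template_json_filename
/// );
/// ```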
pub struct LocalModelPaths<P: Debug> {
    pub tokenizer_filename: P,
    pub config_filename: P,
    pub template_filename: Option<P>,
    pub filenames: Vec<P>,
    pub adapter_paths: AdapterPaths,
    pub gen_conf: Option<P>,
    pub preprocessor_config: Option<P>,
    pub processor_config: Option<P>,
    pub chat_template_json_filename: Option<P>,
}

impl<P: Debug> LocalModelPaths<P> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tokenizer_filename: P,
        config_filename: P,
        template_filename: P,
        filenames: Vec<P>,
        adapter_paths: AdapterPaths,
        gen_conf: Option<P>,
        preprocessor_config: Option<P>,
        processor_config: Option<P>,
        chat_template_json_filename: Option<P>,
    ) -> Self {
        Self {
            tokenizer_filename,
            config_filename,
            template_filename: Some(template_filename),
            filenames,
            adapter_paths,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }
    }
}

impl ModelPaths for LocalModelPaths<PathBuf> {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config_filename
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        &self.tokenizer_filename
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.filenames
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        &self.template_filename
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        self.gen_conf.as_ref()
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        &self.preprocessor_config
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        &self.processor_config
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        &self.chat_template_json_filename
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        &self.adapter_paths
    }
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
        None
    }
}

#[derive(Clone, Debug)]
/// All local paths and metadata necessary to load an embedding model.
pub struct EmbeddingModelPaths<P: Debug> {
    pub tokenizer_filename: P,
    pub config_filename: P,
    pub modules: Vec<EmbeddingModulePaths>,
    pub filenames: Vec<P>,
    pub adapter_paths: AdapterPaths,
}

impl<P: Debug> EmbeddingModelPaths<P> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tokenizer_filename: P,
        config_filename: P,
        filenames: Vec<P>,
        adapter_paths: AdapterPaths,
        modules: Vec<EmbeddingModulePaths>,
    ) -> Self {
        Self {
            tokenizer_filename,
            config_filename,
            filenames,
            adapter_paths,
            modules,
        }
    }
}

impl ModelPaths for EmbeddingModelPaths<PathBuf> {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config_filename
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        &self.tokenizer_filename
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.filenames
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        &None
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        None
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        &None
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        &None
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        &None
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        &self.adapter_paths
    }
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
        Some(&self.modules)
    }
}

#[derive(Debug, Clone)]
/// The source of the HF token.
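///
/// A token source can be parsed from a string: `literal:<token>`, `env:<var>` (the variable
/// name defaults to `HUGGING_FACE_HUB_TOKEN`), `path:<file>`, `cache`, or `none`.
///
/// ```
/// use mistralrs_core::TokenSource;
///
/// let from_env: TokenSource = "env:HF_TOKEN".parse().unwrap();
/// assert_eq!(from_env.to_string(), "env:HF_TOKEN");
///
/// let cached: TokenSource = "cache".parse().unwrap();
/// assert_eq!(cached.to_string(), "cache");
/// ```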
pub enum TokenSource {
    /// Use the token provided directly as a string.
    Literal(String),
    /// Read the token from the named environment variable.
    EnvVar(String),
    /// Read the token from the file at this path.
    Path(String),
    /// Use the token from the Hugging Face token cache.
    CacheToken,
    /// Do not use a token.
    None,
}

impl FromStr for TokenSource {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parts: Vec<&str> = s.splitn(2, ':').collect();
        match parts[0] {
            "literal" => parts
                .get(1)
                .map(|&value| TokenSource::Literal(value.to_string()))
                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
            "env" => Ok(TokenSource::EnvVar(
                parts
                    .get(1)
                    .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
                    .to_string(),
            )),
            "path" => parts
                .get(1)
                .map(|&value| TokenSource::Path(value.to_string()))
                .ok_or_else(|| "Expected a value for 'path'".to_string()),
            "cache" => Ok(TokenSource::CacheToken),
            "none" => Ok(TokenSource::None),
            _ => Err("Invalid token source format".to_string()),
        }
    }
}

impl fmt::Display for TokenSource {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenSource::Literal(value) => write!(f, "literal:{value}"),
            TokenSource::EnvVar(value) => write!(f, "env:{value}"),
            TokenSource::Path(value) => write!(f, "path:{value}"),
            TokenSource::CacheToken => write!(f, "cache"),
            TokenSource::None => write!(f, "none"),
        }
    }
}

/// The kind of model to build.
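///
/// An illustrative sketch of the `Display` form and the quantization helpers (`ignore`d here,
/// since it assumes these types are re-exported at the crate root):
///
/// ```ignore
/// let kind = ModelKind::GgufQuantized { quant: QuantizationKind::Gguf };
/// assert!(kind.is_quantized());
/// assert_eq!(kind.to_string(), "gguf quantized from gguf (no adapters)");
/// ```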
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}

#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum QuantizationKind {
    /// GGML
    Ggml,
    /// GGUF
    Gguf,
    /// GPTQ
    Gptq,
}

#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum AdapterKind {
    /// LoRA
    Lora,
    /// X-LoRA
    XLora,
}

/// Provides the proper (human-readable) name of a variant, as given by its doc comment, falling
/// back to the kebab-case `Display` form when no doc comment is present.
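///
/// For example (illustrative sketch, `ignore`d here):
///
/// ```ignore
/// assert_eq!(AdapterKind::XLora.to_string(), "x-lora");   // kebab-case `Display`
/// assert_eq!(AdapterKind::XLora.pretty_name(), "X-LoRA"); // from the variant's doc comment
/// ```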
pub trait PrettyName: strum::EnumMessage + ToString {
    fn pretty_name(&self) -> String {
        match self.get_documentation() {
            Some(s) => s.to_string(),
            // Instead of panic via expect(),
            // fallback to default kebab-case:
            None => self.to_string(),
        }
    }
}

impl PrettyName for AdapterKind {}
impl PrettyName for QuantizationKind {}

impl ModelKind {
    // Quantized helpers:
    pub fn is_quantized(&self) -> bool {
        self.quantized_kind().iter().any(|q| q.is_some())
    }

    pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
        self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
    }

    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
        use ModelKind::*;

        match self {
            Normal | Adapter { .. } => vec![None],
            GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
            Speculative { target, draft } => {
                let t = *target.clone();
                let d = *draft.clone();

                [t.quantized_kind(), d.quantized_kind()].concat()
            }
            AnyMoe { target } => target.quantized_kind(),
        }
    }

    // Adapter helpers:
    pub fn is_adapted(&self) -> bool {
        self.adapted_kind().iter().any(|a| a.is_some())
    }

    pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
        self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
    }

    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
        use ModelKind::*;

        match self {
            Normal | GgufQuantized { .. } => vec![None],
            Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
            Speculative { target, draft } => {
                let t = *target.clone();
                let d = *draft.clone();

                [t.adapted_kind(), d.adapted_kind()].concat()
            }
            AnyMoe { target } => target.adapted_kind(),
        }
    }
}

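/// Reads the optional `quantization_config` section of a model's `config.json` to determine the
/// weight pack factor (see `weight_pack_factor` in [`DeviceMappedModelLoader`]); the factor is 1
/// when the model is not pre-quantized.
///
/// An illustrative sketch (`ignore`d here):
///
/// ```ignore
/// use candle_core::DType;
///
/// // No `quantization_config` key => pack factor of 1.
/// let factor = QuantizationConfigShim::get_quant_config_pack_factor("{}", DType::F16).unwrap();
/// assert_eq!(factor, 1);
/// ```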
#[derive(Deserialize)]
pub struct QuantizationConfigShim {
    quantization_config: Option<QuantizedConfig>,
}

impl QuantizationConfigShim {
    pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
        let QuantizationConfigShim {
            quantization_config,
        } = serde_json::from_str(config)?;

        if let Some(quantization_config) = quantization_config {
            Ok(quantization_config.pack_factor(dtype))
        } else {
            Ok(1)
        }
    }
}

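/// Size and activation metadata used for automatic device mapping of a model: per-layer and
/// non-mapped weight sizes in bytes, maximum activation sizes, and the number of layers. The
/// provided [`DeviceMappedModelLoader::get_device_layers`] uses these estimates to assign layers
/// to the available devices.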
pub trait DeviceMappedModelLoader {
    /// Maximum activation size of non-mapped parts of this model.
    /// Useful for vision models, which may prefer to keep the vision components on the GPU.
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Maximum activation size of mapped parts of the model.
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// `weight_pack_factor` only applies to quantized weights.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize>;
    /// `weight_pack_factor` only applies to quantized weights.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>>;
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    fn num_layers(&self, config: &str) -> Result<usize>;
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata>
    where
        Self: Sized,
    {
        auto_device_map::get_device_layers(
            self,
            config,
            num_layers,
            layer_sizes_in_bytes,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            paged_attn_config,
        )
    }
}

/// The `Loader` trait abstracts the loading process. The primary entrypoints are the
/// `load_model_from_hf` and `load_model_from_path` methods.
///
/// # Example
/// ```no_run
/// use mistralrs_core::{Loader, TokenSource, DeviceMapSetting, AutoDeviceMapParams, ModelDType};
/// use candle_core::Device;
///
/// let loader: Box<dyn Loader> = todo!();
/// let pipeline = loader.load_model_from_hf(
///     None,
///     TokenSource::CacheToken,
///     &ModelDType::Auto,
///     &Device::cuda_if_available(0).unwrap(),
///     false,
///     DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
///     None,
///     None,
/// ).unwrap();
/// ```
pub trait Loader: Send + Sync {
    /// If `revision` is None, then it defaults to `main`.
    /// If `dtype` is `Auto`, then it defaults to the model default (usually BF16).
    /// If the model is not found on HF, this will attempt to resolve it locally.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Load a model from the specified paths.
    /// Also initializes `DEBUG`.
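    ///
    /// A sketch mirroring the [`Loader`] example above, but starting from already-resolved
    /// paths (`ignore`d here, since it assumes the relevant re-exports):
    ///
    /// ```ignore
    /// use mistralrs_core::{Loader, ModelPaths, DeviceMapSetting, AutoDeviceMapParams, ModelDType};
    /// use candle_core::Device;
    ///
    /// let loader: Box<dyn Loader> = todo!();
    /// let paths: Box<dyn ModelPaths> = todo!();
    /// let pipeline = loader.load_model_from_path(
    ///     &paths,
    ///     &ModelDType::Auto,
    ///     &Device::cuda_if_available(0).unwrap(),
    ///     false,
    ///     DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
    ///     None,
    ///     None,
    /// ).unwrap();
    /// ```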
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    fn get_id(&self) -> String;
    fn get_kind(&self) -> ModelKind;
}