mistralrs_core/pipeline/loaders/
mod.rs

1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod embedding_loaders;
4mod normal_loaders;
5mod vision_loaders;
6pub use auto_device_map::AutoDeviceMapParams;
7use auto_device_map::NonMappedSubModel;
8
9use std::{
10    fmt::{self, Debug},
11    path::PathBuf,
12    str::FromStr,
13    sync::Arc,
14};
15
16use anyhow::Result;
17use as_any::AsAny;
18use candle_core::{DType, Device};
19use mistralrs_quant::{IsqType, QuantizedConfig};
20use serde::Deserialize;
21use tokio::sync::Mutex;
22
23pub use normal_loaders::{
24    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
25    LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata,
26    NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader,
27    Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
28};
29
30pub use vision_loaders::{
31    AutoVisionLoader, Gemma3Loader, Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
32    LLaVANextLoader, MiniCpmOLoader, Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader,
33    Qwen2_5VLLoader, Qwen3VLLoader, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel,
34    VisionModelLoader,
35};
36
37pub use embedding_loaders::{
38    AutoEmbeddingLoader, EmbeddingGemmaLoader, EmbeddingLoaderType, EmbeddingModel,
39    EmbeddingModelLoader, EmbeddingModule, EmbeddingModulePaths, EmbeddingModuleType,
40    Qwen3EmbeddingLoader,
41};
42
43pub use diffusion_loaders::{
44    DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
45    DiffusionModelPathsInner, FluxLoader,
46};
47
48use crate::{
49    matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
50    DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
51};
52
53use super::{paths::AdapterPaths, Pipeline};
54
/// `ModelPaths` abstracts the mechanism to get all necessary files for running a model. For
/// example `LocalModelPaths` implements `ModelPaths` when all files are in the local file system.
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Model weights files (multiple files supported).
    fn get_weight_filenames(&self) -> &[PathBuf];

    /// Retrieve the [`PretrainedConfig`] file.
    ///
    /// [`PretrainedConfig`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/configuration#transformers.PretrainedConfig
    fn get_config_filename(&self) -> &PathBuf;

    /// A serialised [`tokenizers.Tokenizer`] HuggingFace object.
    ///
    /// [`tokenizers.Tokenizer`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// File where the content is expected to deserialize to [`ChatTemplate`].
    /// `None` when the model supplies no chat template file.
    ///
    /// [`ChatTemplate`]: crate::ChatTemplate
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Filepath for general model configuration, if present.
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;

    /// Get the preprocessor config (for the vision models). This is used to pre process images.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;

    /// Get the processor config (for the vision models). This is primarily used for the chat template.
    fn get_processor_config(&self) -> &Option<PathBuf>;

    /// Get the explicit chat template file, if one was provided separately from
    /// the tokenizer/template files above.
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;

    /// Get adapter (LoRA/X-LoRA) paths.
    fn get_adapter_paths(&self) -> &AdapterPaths;

    /// Get embedding model `modules.json` compatible with sentence-transformers.
    /// `None` for non-embedding models.
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]>;
}
94
#[derive(Clone, Debug)]
/// All local paths and metadata necessary to load a model.
pub struct LocalModelPaths<P: Debug> {
    // Tokenizer file (see `ModelPaths::get_tokenizer_filename`).
    pub tokenizer_filename: P,
    // Model configuration file (see `ModelPaths::get_config_filename`).
    pub config_filename: P,
    // Chat template file; always populated as `Some(..)` by `new`.
    pub template_filename: Option<P>,
    // Weight files (multiple supported).
    pub filenames: Vec<P>,
    // Adapter (LoRA/X-LoRA) paths.
    pub adapter_paths: AdapterPaths,
    // Optional generation configuration file.
    pub gen_conf: Option<P>,
    // Optional vision preprocessor config.
    pub preprocessor_config: Option<P>,
    // Optional vision processor config.
    pub processor_config: Option<P>,
    // Optional explicit chat template JSON file.
    pub chat_template_json_filename: Option<P>,
}
108
109impl<P: Debug> LocalModelPaths<P> {
110    #[allow(clippy::too_many_arguments)]
111    pub fn new(
112        tokenizer_filename: P,
113        config_filename: P,
114        template_filename: P,
115        filenames: Vec<P>,
116        adapter_paths: AdapterPaths,
117        gen_conf: Option<P>,
118        preprocessor_config: Option<P>,
119        processor_config: Option<P>,
120        chat_template_json_filename: Option<P>,
121    ) -> Self {
122        Self {
123            tokenizer_filename,
124            config_filename,
125            template_filename: Some(template_filename),
126            filenames,
127            adapter_paths,
128            gen_conf,
129            preprocessor_config,
130            processor_config,
131            chat_template_json_filename,
132        }
133    }
134}
135
136impl ModelPaths for LocalModelPaths<PathBuf> {
137    fn get_config_filename(&self) -> &PathBuf {
138        &self.config_filename
139    }
140    fn get_tokenizer_filename(&self) -> &PathBuf {
141        &self.tokenizer_filename
142    }
143    fn get_weight_filenames(&self) -> &[PathBuf] {
144        &self.filenames
145    }
146    fn get_template_filename(&self) -> &Option<PathBuf> {
147        &self.template_filename
148    }
149    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
150        self.gen_conf.as_ref()
151    }
152    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
153        &self.preprocessor_config
154    }
155    fn get_processor_config(&self) -> &Option<PathBuf> {
156        &self.processor_config
157    }
158    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
159        &self.chat_template_json_filename
160    }
161    fn get_adapter_paths(&self) -> &AdapterPaths {
162        &self.adapter_paths
163    }
164    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
165        None
166    }
167}
168
#[derive(Clone, Debug)]
/// All local paths and metadata necessary to load an embedding model.
pub struct EmbeddingModelPaths<P: Debug> {
    // Tokenizer file (see `ModelPaths::get_tokenizer_filename`).
    pub tokenizer_filename: P,
    // Model configuration file (see `ModelPaths::get_config_filename`).
    pub config_filename: P,
    // sentence-transformers-style module descriptors (`modules.json`).
    pub modules: Vec<EmbeddingModulePaths>,
    // Weight files (multiple supported).
    pub filenames: Vec<P>,
    // Adapter (LoRA/X-LoRA) paths.
    pub adapter_paths: AdapterPaths,
}
178
179impl<P: Debug> EmbeddingModelPaths<P> {
180    #[allow(clippy::too_many_arguments)]
181    pub fn new(
182        tokenizer_filename: P,
183        config_filename: P,
184        filenames: Vec<P>,
185        adapter_paths: AdapterPaths,
186        modules: Vec<EmbeddingModulePaths>,
187    ) -> Self {
188        Self {
189            tokenizer_filename,
190            config_filename,
191            filenames,
192            adapter_paths,
193            modules,
194        }
195    }
196}
197
198impl ModelPaths for EmbeddingModelPaths<PathBuf> {
199    fn get_config_filename(&self) -> &PathBuf {
200        &self.config_filename
201    }
202    fn get_tokenizer_filename(&self) -> &PathBuf {
203        &self.tokenizer_filename
204    }
205    fn get_weight_filenames(&self) -> &[PathBuf] {
206        &self.filenames
207    }
208    fn get_template_filename(&self) -> &Option<PathBuf> {
209        &None
210    }
211    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
212        None
213    }
214    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
215        &None
216    }
217    fn get_processor_config(&self) -> &Option<PathBuf> {
218        &None
219    }
220    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
221        &None
222    }
223    fn get_adapter_paths(&self) -> &AdapterPaths {
224        &self.adapter_paths
225    }
226    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
227        Some(&self.modules)
228    }
229}
230
#[derive(Debug, Clone)]
/// The source of the HF token.
pub enum TokenSource {
    /// Token provided verbatim.
    Literal(String),
    /// Token read from the named environment variable.
    EnvVar(String),
    /// Token read from the file at this path.
    Path(String),
    /// Token taken from the local HF cache.
    CacheToken,
    /// No token.
    None,
}

impl FromStr for TokenSource {
    type Err = String;

    /// Parse a `kind[:value]` specifier: `literal:<token>`, `env[:VAR]`,
    /// `path:<file>`, `cache`, or `none`.
    ///
    /// `literal` and `path` require a value; `env` defaults to
    /// `HUGGING_FACE_HUB_TOKEN` when no variable name is given.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `split_once` avoids the intermediate `Vec` (and the `parts[0]`
        // indexing) that `splitn(2, ':').collect()` required.
        let (kind, value) = match s.split_once(':') {
            Some((kind, value)) => (kind, Some(value)),
            None => (s, None),
        };
        match kind {
            "literal" => value
                .map(|value| TokenSource::Literal(value.to_string()))
                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
            "env" => Ok(TokenSource::EnvVar(
                value.unwrap_or("HUGGING_FACE_HUB_TOKEN").to_string(),
            )),
            "path" => value
                .map(|value| TokenSource::Path(value.to_string()))
                .ok_or_else(|| "Expected a value for 'path'".to_string()),
            "cache" => Ok(TokenSource::CacheToken),
            "none" => Ok(TokenSource::None),
            _ => Err("Invalid token source format".to_string()),
        }
    }
}

impl fmt::Display for TokenSource {
    /// Render in the same `kind[:value]` syntax accepted by `FromStr`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenSource::Literal(value) => write!(f, "literal:{value}"),
            TokenSource::EnvVar(value) => write!(f, "env:{value}"),
            TokenSource::Path(value) => write!(f, "path:{value}"),
            TokenSource::CacheToken => write!(f, "cache"),
            TokenSource::None => write!(f, "none"),
        }
    }
}
279
/// The kind of model to build.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// Plain model: no quantization, no adapters.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// GGUF-quantized model without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// Adapter-augmented (LoRA/X-LoRA) model without quantization.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// GGUF-quantized model with adapters.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// Speculative decoding: a target model paired with a draft model.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// AnyMoe built on top of a target model kind.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
308
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum QuantizationKind {
    // NOTE: the `///` doc comments below are load-bearing — `PrettyName` reads
    // them at runtime via `strum::EnumMessage::get_documentation`. They are
    // the display names; do not edit them casually.
    /// GGML
    Ggml,
    /// GGUF
    Gguf,
    /// GPTQ
    Gptq,
}
319
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum AdapterKind {
    // NOTE: the `///` doc comments below are load-bearing — `PrettyName` reads
    // them at runtime via `strum::EnumMessage::get_documentation`. They are
    // the display names; do not edit them casually.
    /// LoRA
    Lora,
    /// X-LoRA
    XLora,
}
328
329// For the proper name as formatted via doc comment for a variant
330pub trait PrettyName: strum::EnumMessage + ToString {
331    fn pretty_name(&self) -> String {
332        match self.get_documentation() {
333            Some(s) => s.to_string(),
334            // Instead of panic via expect(),
335            // fallback to default kebab-case:
336            None => self.to_string(),
337        }
338    }
339}
340
341impl PrettyName for AdapterKind {}
342impl PrettyName for QuantizationKind {}
343
344impl ModelKind {
345    // Quantized helpers:
346    pub fn is_quantized(&self) -> bool {
347        self.quantized_kind().iter().any(|q| q.is_some())
348    }
349
350    pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
351        self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
352    }
353
354    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
355        use ModelKind::*;
356
357        match self {
358            Normal | Adapter { .. } => vec![None],
359            GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
360            Speculative { target, draft } => {
361                let t = *target.clone();
362                let d = *draft.clone();
363
364                [t.quantized_kind(), d.quantized_kind()].concat()
365            }
366            AnyMoe { target } => target.quantized_kind(),
367        }
368    }
369
370    // Adapter helpers:
371    pub fn is_adapted(&self) -> bool {
372        self.adapted_kind().iter().any(|a| a.is_some())
373    }
374
375    pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
376        self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
377    }
378
379    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
380        use ModelKind::*;
381
382        match self {
383            Normal | GgufQuantized { .. } => vec![None],
384            Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
385            Speculative { target, draft } => {
386                let t = *target.clone();
387                let d = *draft.clone();
388
389                [t.adapted_kind(), d.adapted_kind()].concat()
390            }
391            AnyMoe { target } => target.adapted_kind(),
392        }
393    }
394}
395
/// Deserialization shim: extracts only the optional `quantization_config` key
/// from a full model `config.json`, ignoring every other field.
#[derive(Deserialize)]
pub struct QuantizationConfigShim {
    quantization_config: Option<QuantizedConfig>,
}
400
401impl QuantizationConfigShim {
402    pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
403        let QuantizationConfigShim {
404            quantization_config,
405        } = serde_json::from_str(config)?;
406
407        if let Some(quantization_config) = quantization_config {
408            Ok(quantization_config.pack_factor(dtype))
409        } else {
410            Ok(1)
411        }
412    }
413}
414
/// Sizing and layout queries used to automatically map a model's layers across
/// devices. All `config` parameters take the serialized model configuration
/// (JSON string) — see `model_config`/`num_layers` usage below.
pub trait DeviceMappedModelLoader {
    /// Maximum activation size of non-mapped parts of this model.
    /// Useful for the vision models which may prefer to keep the vision components on the GPU.
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Maximum activation size of mapped parts of the model
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Total size in bytes of the non-mapped parts of the model.
    /// weight_pack_factor only applies to quantized weights.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize>;
    /// Size in bytes of each mapped layer (one entry per layer).
    /// weight_pack_factor only applies to quantized weights.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>>;
    /// Sub-models (if any) that must not be device-mapped. Defaults to `None`.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    /// Number of device-mappable layers for this `config`.
    fn num_layers(&self, config: &str) -> Result<usize>;
    /// Parse `config` into a [`ModelConfigLike`] describing the model geometry.
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    /// Compute the per-device layer assignment from the sizing information
    /// above. Default implementation delegates to
    /// `auto_device_map::get_device_layers`.
    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata>
    where
        Self: Sized,
    {
        auto_device_map::get_device_layers(
            self,
            config,
            num_layers,
            layer_sizes_in_bytes,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            paged_attn_config,
        )
    }
}
481
/// The `Loader` trait abstracts the loading process. The primary entrypoint is the
/// `load_model` method.
///
/// # Example
/// ```no_run
/// use mistralrs_core::{Loader, TokenSource, DeviceMapSetting, AutoDeviceMapParams, ModelDType};
/// use candle_core::Device;
///
/// let loader: Box<dyn Loader> = todo!();
/// let pipeline = loader.load_model_from_hf(
///     None,
///     TokenSource::CacheToken,
///     &ModelDType::Auto,
///     &Device::cuda_if_available(0).unwrap(),
///     false,
///     DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
///     None,
///     None,
/// ).unwrap();
/// ```
pub trait Loader: Send + Sync {
    /// If `revision` is None, then it defaults to `main`.
    /// If `dtype` is None, then it defaults to the model default (usually BF16).
    /// If model is not found on HF, will attempt to resolve locally.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Load a model from the specified paths.
    /// Also initializes `DEBUG`.
    // NOTE(review): `&Box<dyn ModelPaths>` (hence the `clippy::borrowed_box`
    // allow) is kept for caller compatibility; `&dyn ModelPaths` would be the
    // idiomatic parameter type but would change the public signature.
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Identifier of the model this loader was created for — presumably the
    /// model ID string; confirm against implementors.
    fn get_id(&self) -> String;
    /// The [`ModelKind`] this loader produces.
    fn get_kind(&self) -> ModelKind;
}