mistralrs_core/pipeline/loaders/
mod.rs

1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod normal_loaders;
4mod vision_loaders;
5pub use auto_device_map::AutoDeviceMapParams;
6use auto_device_map::NonMappedSubModel;
7
8use std::{
9    fmt::{self, Debug},
10    path::PathBuf,
11    str::FromStr,
12    sync::Arc,
13};
14
15use anyhow::Result;
16use as_any::AsAny;
17use candle_core::{DType, Device};
18use mistralrs_quant::{IsqType, QuantizedConfig};
19use serde::Deserialize;
20use tokio::sync::Mutex;
21
22pub use normal_loaders::{
23    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
24    LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata,
25    NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader,
26    Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
27};
28
29pub use vision_loaders::{
30    AutoVisionLoader, Gemma3Loader, Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
31    LLaVANextLoader, MiniCpmOLoader, Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader,
32    Qwen2_5VLLoader, Qwen3VLLoader, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel,
33    VisionModelLoader,
34};
35
36pub use diffusion_loaders::{
37    DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
38    DiffusionModelPathsInner, FluxLoader,
39};
40
41use crate::{
42    matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
43    DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
44};
45
46use super::{paths::AdapterPaths, Pipeline};
47
48/// `ModelPaths` abstracts the mechanism to get all necessary files for running a model. For
49/// example `LocalModelPaths` implements `ModelPaths` when all files are in the local file system.
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Model weights files (multiple files supported, e.g. sharded safetensors).
    fn get_weight_filenames(&self) -> &[PathBuf];

    /// Retrieve the [`PretrainedConfig`] file.
    ///
    /// [`PretrainedConfig`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/configuration#transformers.PretrainedConfig
    fn get_config_filename(&self) -> &PathBuf;

    /// A serialised [`tokenizers.Tokenizer`] HuggingFace object.
    ///
    /// [`tokenizers.Tokenizer`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// File where the content is expected to deserialize to [`ChatTemplate`].
    /// `None` when the model ships no template file.
    ///
    /// [`ChatTemplate`]: crate::ChatTemplate
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Filepath for general model configuration.
    // NOTE(review): judging by the field name (`gen_conf`) this is likely the
    // generation config (`generation_config.json`) — confirm against callers.
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;

    /// Get the preprocessor config (for the vision models). This is used to pre process images.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;

    /// Get the processor config (for the vision models). This is primarily used for the chat template.
    fn get_processor_config(&self) -> &Option<PathBuf>;

    /// Get the explicit chat template (a standalone JSON chat-template file,
    /// taking precedence over the one embedded in the tokenizer config).
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;

    /// Get adapter paths (LoRA/X-LoRA related files).
    fn get_adapter_paths(&self) -> &AdapterPaths;
}
84
#[derive(Clone, Debug)]
/// All local paths and metadata necessary to load a model.
///
/// Generic over the path type `P` (in practice `PathBuf`, for which
/// [`ModelPaths`] is implemented).
pub struct LocalModelPaths<P: Debug> {
    /// Serialized HuggingFace tokenizer file.
    pub tokenizer_filename: P,
    /// Model `config.json`.
    pub config_filename: P,
    /// Chat-template-bearing file, if any.
    pub template_filename: Option<P>,
    /// Weight files (may be sharded across several files).
    pub filenames: Vec<P>,
    /// LoRA/X-LoRA adapter files.
    pub adapter_paths: AdapterPaths,
    /// Generation configuration file, if present.
    pub gen_conf: Option<P>,
    /// Vision-model preprocessor config, if present.
    pub preprocessor_config: Option<P>,
    /// Vision-model processor config, if present.
    pub processor_config: Option<P>,
    /// Explicit standalone chat-template JSON, if present.
    pub chat_template_json_filename: Option<P>,
}
98
99impl<P: Debug> LocalModelPaths<P> {
100    #[allow(clippy::too_many_arguments)]
101    pub fn new(
102        tokenizer_filename: P,
103        config_filename: P,
104        template_filename: P,
105        filenames: Vec<P>,
106        adapter_paths: AdapterPaths,
107        gen_conf: Option<P>,
108        preprocessor_config: Option<P>,
109        processor_config: Option<P>,
110        chat_template_json_filename: Option<P>,
111    ) -> Self {
112        Self {
113            tokenizer_filename,
114            config_filename,
115            template_filename: Some(template_filename),
116            filenames,
117            adapter_paths,
118            gen_conf,
119            preprocessor_config,
120            processor_config,
121            chat_template_json_filename,
122        }
123    }
124}
125
126impl ModelPaths for LocalModelPaths<PathBuf> {
127    fn get_config_filename(&self) -> &PathBuf {
128        &self.config_filename
129    }
130    fn get_tokenizer_filename(&self) -> &PathBuf {
131        &self.tokenizer_filename
132    }
133    fn get_weight_filenames(&self) -> &[PathBuf] {
134        &self.filenames
135    }
136    fn get_template_filename(&self) -> &Option<PathBuf> {
137        &self.template_filename
138    }
139    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
140        self.gen_conf.as_ref()
141    }
142    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
143        &self.preprocessor_config
144    }
145    fn get_processor_config(&self) -> &Option<PathBuf> {
146        &self.processor_config
147    }
148    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
149        &self.chat_template_json_filename
150    }
151    fn get_adapter_paths(&self) -> &AdapterPaths {
152        &self.adapter_paths
153    }
154}
155
#[derive(Debug, Clone)]
/// The source of the HF token.
///
/// Parsed from strings of the form `kind` or `kind:value` via [`FromStr`],
/// and formatted back to that form by its `Display` impl.
pub enum TokenSource {
    /// The token itself, provided inline (`literal:<token>`).
    Literal(String),
    /// Name of an environment variable holding the token (`env:<var>`;
    /// defaults to `HUGGING_FACE_HUB_TOKEN` when no name is given).
    EnvVar(String),
    /// Path to a file containing the token (`path:<file>`).
    Path(String),
    /// Use the cached token (`cache`).
    CacheToken,
    /// Do not use a token (`none`).
    None,
}
165
166impl FromStr for TokenSource {
167    type Err = String;
168
169    fn from_str(s: &str) -> Result<Self, Self::Err> {
170        let parts: Vec<&str> = s.splitn(2, ':').collect();
171        match parts[0] {
172            "literal" => parts
173                .get(1)
174                .map(|&value| TokenSource::Literal(value.to_string()))
175                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
176            "env" => Ok(TokenSource::EnvVar(
177                parts
178                    .get(1)
179                    .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
180                    .to_string(),
181            )),
182            "path" => parts
183                .get(1)
184                .map(|&value| TokenSource::Path(value.to_string()))
185                .ok_or_else(|| "Expected a value for 'path'".to_string()),
186            "cache" => Ok(TokenSource::CacheToken),
187            "none" => Ok(TokenSource::None),
188            _ => Err("Invalid token source format".to_string()),
189        }
190    }
191}
192
193impl fmt::Display for TokenSource {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        match self {
196            TokenSource::Literal(value) => write!(f, "literal:{value}"),
197            TokenSource::EnvVar(value) => write!(f, "env:{value}"),
198            TokenSource::Path(value) => write!(f, "path:{value}"),
199            TokenSource::CacheToken => write!(f, "cache"),
200            TokenSource::None => write!(f, "none"),
201        }
202    }
203}
204
/// The kind of model to build.
///
/// Display strings come from the `strum(to_string = ...)` attributes below,
/// not from these doc comments.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// Plain model: no quantization, no adapters.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// GGUF-quantized model without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// Unquantized model with an adapter (LoRA/X-LoRA).
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// GGUF-quantized model with an adapter.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// Speculative decoding: a target model plus a draft model.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// AnyMoE wrapper around a target model.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
233
/// How a model's weights were quantized.
///
/// NOTE: the `///` comments on the variants below double as the
/// `strum::EnumMessage` documentation returned by [`PrettyName::pretty_name`],
/// so they are load-bearing — do not edit them casually.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum QuantizationKind {
    /// GGML
    Ggml,
    /// GGUF
    Gguf,
    /// GPTQ
    Gptq,
}
244
/// Which adapter scheme a model uses.
///
/// NOTE: the `///` comments on the variants below double as the
/// `strum::EnumMessage` documentation returned by [`PrettyName::pretty_name`],
/// so they are load-bearing — do not edit them casually.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum AdapterKind {
    /// LoRA
    Lora,
    /// X-LoRA
    XLora,
}
253
254// For the proper name as formatted via doc comment for a variant
255pub trait PrettyName: strum::EnumMessage + ToString {
256    fn pretty_name(&self) -> String {
257        match self.get_documentation() {
258            Some(s) => s.to_string(),
259            // Instead of panic via expect(),
260            // fallback to default kebab-case:
261            None => self.to_string(),
262        }
263    }
264}
265
266impl PrettyName for AdapterKind {}
267impl PrettyName for QuantizationKind {}
268
269impl ModelKind {
270    // Quantized helpers:
271    pub fn is_quantized(&self) -> bool {
272        self.quantized_kind().iter().any(|q| q.is_some())
273    }
274
275    pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
276        self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
277    }
278
279    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
280        use ModelKind::*;
281
282        match self {
283            Normal | Adapter { .. } => vec![None],
284            GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
285            Speculative { target, draft } => {
286                let t = *target.clone();
287                let d = *draft.clone();
288
289                [t.quantized_kind(), d.quantized_kind()].concat()
290            }
291            AnyMoe { target } => target.quantized_kind(),
292        }
293    }
294
295    // Adapter helpers:
296    pub fn is_adapted(&self) -> bool {
297        self.adapted_kind().iter().any(|a| a.is_some())
298    }
299
300    pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
301        self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
302    }
303
304    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
305        use ModelKind::*;
306
307        match self {
308            Normal | GgufQuantized { .. } => vec![None],
309            Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
310            Speculative { target, draft } => {
311                let t = *target.clone();
312                let d = *draft.clone();
313
314                [t.adapted_kind(), d.adapted_kind()].concat()
315            }
316            AnyMoe { target } => target.adapted_kind(),
317        }
318    }
319}
320
/// Minimal deserialization shim that extracts only the
/// `quantization_config` entry from a model configuration JSON string.
#[derive(Deserialize)]
pub struct QuantizationConfigShim {
    // `None` when the configuration declares no quantization.
    quantization_config: Option<QuantizedConfig>,
}
325
326impl QuantizationConfigShim {
327    pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
328        let QuantizationConfigShim {
329            quantization_config,
330        } = serde_json::from_str(config)?;
331
332        if let Some(quantization_config) = quantization_config {
333            Ok(quantization_config.pack_factor(dtype))
334        } else {
335            Ok(1)
336        }
337    }
338}
339
/// Per-architecture size/shape oracle used by automatic device mapping to
/// decide how many layers fit on each device.
pub trait DeviceMappedModelLoader {
    /// Maximum activation size of non-mapped parts of this model.
    /// Useful for the vision models which may prefer to keep the vision components on the GPU.
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Maximum activation size of mapped parts of the model
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Size in bytes of the non-mapped parts of the model.
    /// weight_pack_factor only applies to quantized weights.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize>;
    /// Per-layer sizes in bytes for the mapped (repeated) layers.
    /// weight_pack_factor only applies to quantized weights.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>>;
    /// Sub-models that are never device-mapped. Defaults to `None`
    /// (everything is mappable).
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    /// Number of mappable layers declared by `config`.
    fn num_layers(&self, config: &str) -> Result<usize>;
    /// Parse `config` into a generic model-shape description.
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    /// Compute the layer→device assignment. Default implementation simply
    /// delegates to [`auto_device_map::get_device_layers`].
    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata>
    where
        Self: Sized,
    {
        auto_device_map::get_device_layers(
            self,
            config,
            num_layers,
            layer_sizes_in_bytes,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            paged_attn_config,
        )
    }
}
406
/// The `Loader` trait abstracts the loading process. The primary entrypoint is the
/// `load_model` method.
///
/// # Example
/// ```no_run
/// use mistralrs_core::{Loader, TokenSource, DeviceMapSetting, AutoDeviceMapParams, ModelDType};
/// use candle_core::Device;
///
/// let loader: Box<dyn Loader> = todo!();
/// let pipeline = loader.load_model_from_hf(
///     None,
///     TokenSource::CacheToken,
///     &ModelDType::Auto,
///     &Device::cuda_if_available(0).unwrap(),
///     false,
///     DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
///     None,
///     None,
/// ).unwrap();
/// ```
pub trait Loader: Send + Sync {
    /// Load a model by resolving its files from the HuggingFace Hub.
    ///
    /// If `revision` is None, then it defaults to `main`.
    /// If `dtype` is None, then it defaults to the model default (usually BF16).
    /// If model is not found on HF, will attempt to resolve locally.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Load a model from the specified paths.
    /// Also initializes `DEBUG`.
    // NOTE(review): `&Box<dyn ModelPaths>` (hence the `borrowed_box` allow)
    // would normally be `&dyn ModelPaths`, but changing it now would break
    // every implementer/caller of this trait.
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Identifier for the model this loader targets (e.g. its model ID).
    fn get_id(&self) -> String;
    /// The [`ModelKind`] this loader produces.
    fn get_kind(&self) -> ModelKind;
}