// mistralrs_core/pipeline/loaders/mod.rs

1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod normal_loaders;
4mod vision_loaders;
5pub use auto_device_map::AutoDeviceMapParams;
6use auto_device_map::NonMappedSubModel;
7
8use std::{
9    fmt::{self, Debug},
10    path::PathBuf,
11    str::FromStr,
12    sync::Arc,
13};
14
15use anyhow::Result;
16use as_any::AsAny;
17use candle_core::{DType, Device};
18use mistralrs_quant::{IsqType, QuantizedConfig};
19use serde::Deserialize;
20use tokio::sync::Mutex;
21
22pub use normal_loaders::{
23    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
24    LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata,
25    NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader,
26    Qwen3Loader, Qwen3MoELoader, Starcoder2Loader,
27};
28
29pub use vision_loaders::{
30    AutoVisionLoader, Gemma3Loader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader,
31    MiniCpmOLoader, Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader, Qwen2_5VLLoader,
32    VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader,
33};
34
35pub use diffusion_loaders::{
36    DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
37    DiffusionModelPathsInner, FluxLoader,
38};
39
40use crate::{
41    paged_attention::ModelConfigLike, DeviceMapMetadata, DeviceMapSetting, PagedAttentionConfig,
42    TryIntoDType,
43};
44
45use super::{paths::AdapterPaths, Pipeline};
46
/// `ModelPaths` abstracts the mechanism to get all necessary files for running a model. For
/// example `LocalModelPaths` implements `ModelPaths` when all files are in the local file system.
///
/// Getters returning `Option` yield `None` when the corresponding file was not
/// provided/resolved for this model.
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Model weights files (multiple files supported).
    fn get_weight_filenames(&self) -> &[PathBuf];

    /// Retrieve the [`PretrainedConfig`] file.
    ///
    /// [`PretrainedConfig`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/configuration#transformers.PretrainedConfig
    fn get_config_filename(&self) -> &PathBuf;

    /// A serialised [`tokenizers.Tokenizer`] HuggingFace object.
    ///
    /// [`tokenizers.Tokenizer`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// File where the content is expected to deserialize to [`ChatTemplate`].
    ///
    /// [`ChatTemplate`]: crate::ChatTemplate
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Filepath for general model configuration.
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;

    /// Get the preprocessor config (for the vision models). This is used to pre process images.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;

    /// Get the processor config (for the vision models). This is primarily used for the chat template.
    fn get_processor_config(&self) -> &Option<PathBuf>;

    /// Get the explicit chat template.
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;

    /// Get adapter paths.
    fn get_adapter_paths(&self) -> &AdapterPaths;
}
83
#[derive(Clone, Debug)]
/// All local paths and metadata necessary to load a model.
pub struct LocalModelPaths<P: Debug> {
    /// Serialized tokenizer (e.g. `tokenizer.json`).
    pub tokenizer_filename: P,
    /// Model `config.json`.
    pub config_filename: P,
    /// Chat-template file; `Self::new` always populates this with `Some`.
    pub template_filename: Option<P>,
    /// Weight files (safetensors/gguf shards, etc.).
    pub filenames: Vec<P>,
    /// Paths for any LoRA/X-LoRA adapters.
    pub adapter_paths: AdapterPaths,
    /// Generation config, if present.
    pub gen_conf: Option<P>,
    /// Vision preprocessor config, if present.
    pub preprocessor_config: Option<P>,
    /// Vision processor config, if present.
    pub processor_config: Option<P>,
    /// Explicit `chat_template.json`, if present.
    pub chat_template_json_filename: Option<P>,
}
97
98impl<P: Debug> LocalModelPaths<P> {
99    #[allow(clippy::too_many_arguments)]
100    pub fn new(
101        tokenizer_filename: P,
102        config_filename: P,
103        template_filename: P,
104        filenames: Vec<P>,
105        adapter_paths: AdapterPaths,
106        gen_conf: Option<P>,
107        preprocessor_config: Option<P>,
108        processor_config: Option<P>,
109        chat_template_json_filename: Option<P>,
110    ) -> Self {
111        Self {
112            tokenizer_filename,
113            config_filename,
114            template_filename: Some(template_filename),
115            filenames,
116            adapter_paths,
117            gen_conf,
118            preprocessor_config,
119            processor_config,
120            chat_template_json_filename,
121        }
122    }
123}
124
125impl ModelPaths for LocalModelPaths<PathBuf> {
126    fn get_config_filename(&self) -> &PathBuf {
127        &self.config_filename
128    }
129    fn get_tokenizer_filename(&self) -> &PathBuf {
130        &self.tokenizer_filename
131    }
132    fn get_weight_filenames(&self) -> &[PathBuf] {
133        &self.filenames
134    }
135    fn get_template_filename(&self) -> &Option<PathBuf> {
136        &self.template_filename
137    }
138    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
139        self.gen_conf.as_ref()
140    }
141    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
142        &self.preprocessor_config
143    }
144    fn get_processor_config(&self) -> &Option<PathBuf> {
145        &self.processor_config
146    }
147    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
148        &self.chat_template_json_filename
149    }
150    fn get_adapter_paths(&self) -> &AdapterPaths {
151        &self.adapter_paths
152    }
153}
154
#[derive(Debug, Clone)]
/// The source of the HF token.
pub enum TokenSource {
    /// Token provided inline as a string.
    Literal(String),
    /// Name of an environment variable holding the token.
    EnvVar(String),
    /// Filesystem path (presumably to a file containing the token — see callers).
    Path(String),
    /// Use the cached Hugging Face token.
    CacheToken,
    /// No token.
    None,
}

impl FromStr for TokenSource {
    type Err = String;

    /// Parse a `kind[:value]` specification: `literal:<tok>`, `env[:VAR]`,
    /// `path:<file>`, `cache`, or `none`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Split on the first ':' only, so values may themselves contain ':'.
        let (kind, value) = match s.split_once(':') {
            Some((kind, value)) => (kind, Some(value)),
            None => (s, None),
        };
        match kind {
            "literal" => value
                .map(|v| TokenSource::Literal(v.to_string()))
                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
            // Bare `env` falls back to the conventional HF variable name.
            "env" => Ok(TokenSource::EnvVar(
                value.unwrap_or("HUGGING_FACE_HUB_TOKEN").to_string(),
            )),
            "path" => value
                .map(|v| TokenSource::Path(v.to_string()))
                .ok_or_else(|| "Expected a value for 'path'".to_string()),
            "cache" => Ok(TokenSource::CacheToken),
            "none" => Ok(TokenSource::None),
            _ => Err("Invalid token source format".to_string()),
        }
    }
}

impl fmt::Display for TokenSource {
    /// Render in the same `kind[:value]` syntax accepted by `FromStr`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let repr = match self {
            TokenSource::Literal(value) => format!("literal:{}", value),
            TokenSource::EnvVar(value) => format!("env:{}", value),
            TokenSource::Path(value) => format!("path:{}", value),
            TokenSource::CacheToken => "cache".to_string(),
            TokenSource::None => "none".to_string(),
        };
        f.write_str(&repr)
    }
}
203
/// The kind of model to build.
///
/// NOTE: the `#[strum(to_string = ...)]` attributes drive the `Display`
/// derive and are user-visible — keep them stable.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// Plain model: no quantization, no adapters.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// GGUF-quantized model without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// Unquantized model with a (LoRA/X-LoRA) adapter.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// GGUF-quantized model with a (LoRA/X-LoRA) adapter.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// Speculative decoding: a target model paired with a draft model.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// AnyMoE wrapper around a target model kind.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
232
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
// NOTE: the `///` variant docs below are read at runtime via
// `strum::EnumMessage::get_documentation` (see `PrettyName`), so they are
// user-visible output — do not reword them.
pub enum QuantizationKind {
    /// GGML
    Ggml,
    /// GGUF
    Gguf,
    /// GPTQ
    Gptq,
}
243
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
// NOTE: the `///` variant docs below are read at runtime via
// `strum::EnumMessage::get_documentation` (see `PrettyName`), so they are
// user-visible output — do not reword them.
pub enum AdapterKind {
    /// LoRA
    Lora,
    /// X-LoRA
    XLora,
}
252
253// For the proper name as formatted via doc comment for a variant
254pub trait PrettyName: strum::EnumMessage + ToString {
255    fn pretty_name(&self) -> String {
256        match self.get_documentation() {
257            Some(s) => s.to_string(),
258            // Instead of panic via expect(),
259            // fallback to default kebab-case:
260            None => self.to_string(),
261        }
262    }
263}
264
265impl PrettyName for AdapterKind {}
266impl PrettyName for QuantizationKind {}
267
268impl ModelKind {
269    // Quantized helpers:
270    pub fn is_quantized(&self) -> bool {
271        self.quantized_kind().iter().any(|q| q.is_some())
272    }
273
274    pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
275        self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
276    }
277
278    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
279        use ModelKind::*;
280
281        match self {
282            Normal | Adapter { .. } => vec![None],
283            GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
284            Speculative { target, draft } => {
285                let t = *target.clone();
286                let d = *draft.clone();
287
288                [t.quantized_kind(), d.quantized_kind()].concat()
289            }
290            AnyMoe { target } => target.quantized_kind(),
291        }
292    }
293
294    // Adapter helpers:
295    pub fn is_adapted(&self) -> bool {
296        self.adapted_kind().iter().any(|a| a.is_some())
297    }
298
299    pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
300        self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
301    }
302
303    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
304        use ModelKind::*;
305
306        match self {
307            Normal | GgufQuantized { .. } => vec![None],
308            Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
309            Speculative { target, draft } => {
310                let t = *target.clone();
311                let d = *draft.clone();
312
313                [t.adapted_kind(), d.adapted_kind()].concat()
314            }
315            AnyMoe { target } => target.adapted_kind(),
316        }
317    }
318}
319
320#[derive(Deserialize)]
321pub struct QuantizationConfigShim {
322    quantization_config: Option<QuantizedConfig>,
323}
324
325impl QuantizationConfigShim {
326    pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
327        let QuantizationConfigShim {
328            quantization_config,
329        } = serde_json::from_str(config)?;
330
331        if let Some(quantization_config) = quantization_config {
332            Ok(quantization_config.pack_factor(dtype))
333        } else {
334            Ok(1)
335        }
336    }
337}
338
/// Size/shape metadata used by automatic device mapping to budget a model's
/// layers across devices. All quantities are derived from the raw JSON
/// `config` string, so no weights need to be loaded.
pub trait DeviceMappedModelLoader {
    /// Maximum activation size of non-mapped parts of this model.
    /// Useful for the vision models which may prefer to keep the vision components on the GPU.
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Maximum activation size of mapped parts of the model
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize>;
    /// Size in bytes of the non-mapped parts of the model.
    /// weight_pack_factor only applies to quantized weights.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize>;
    /// Per-layer sizes in bytes for the mapped layers.
    /// weight_pack_factor only applies to quantized weights.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>>;
    /// Sub-models that are not device-mapped; `None` (the default) means
    /// there are none to report.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    /// Number of mappable layers, parsed from `config`.
    fn num_layers(&self, config: &str) -> Result<usize>;
    /// Model shape information (see [`ModelConfigLike`]) parsed from `config`.
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    /// Compute the device-to-layer assignment. Default implementation simply
    /// delegates to [`auto_device_map::get_device_layers`].
    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata>
    where
        Self: Sized,
    {
        auto_device_map::get_device_layers(
            self,
            config,
            num_layers,
            layer_sizes_in_bytes,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            prompt_chunksize,
            paged_attn_config,
        )
    }
}
406
/// The `Loader` trait abstracts the loading process. The primary entrypoint is the
/// `load_model` method.
///
/// # Example
/// ```no_run
/// use mistralrs_core::{Loader, TokenSource, DeviceMapSetting, AutoDeviceMapParams, ModelDType};
/// use candle_core::Device;
///
/// let loader: Box<dyn Loader> = todo!();
/// let pipeline = loader.load_model_from_hf(
///     None,
///     TokenSource::CacheToken,
///     &ModelDType::Auto,
///     &Device::cuda_if_available(0).unwrap(),
///     false,
///     DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
///     None,
///     None,
/// ).unwrap();
/// ```
pub trait Loader: Send + Sync {
    /// If `revision` is None, then it defaults to `main`.
    /// If `dtype` is None, then it defaults to the model default (usually BF16).
    /// If model is not found on HF, will attempt to resolve locally.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Load a model from the specified paths.
    /// Also initializes `DEBUG`.
    // `&Box<dyn ModelPaths>` is kept (rather than `&dyn ModelPaths`) for
    // backward compatibility with existing implementors; hence the
    // `clippy::borrowed_box` allow.
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Identifier of the model this loader targets.
    // NOTE(review): appears to be the model ID string — confirm with implementors.
    fn get_id(&self) -> String;
    /// The [`ModelKind`] this loader produces.
    fn get_kind(&self) -> ModelKind;
}