mistralrs_core/pipeline/loaders/normal_loaders.rs

use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    lora::{LoraConfig, Ordering},
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{
        isq::IsqModelLoader,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::varbuilder_utils::DeviceForLoadTensor,
    xlora_models::NonGranularState,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use mistralrs_quant::log::once_log_info;

use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use crate::{
    models,
    xlora_models::{self, XLoraConfig},
};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

/// Interface implemented by every "normal" (text decoder) model: the standard
/// forward pass, the X-LoRA forward pass, and access to the KV cache and the
/// model configuration metadata.
pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn is_xlora(&self) -> bool;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
}

/// Metadata for loading a model with ISQ or device mapping.
pub struct NormalLoadingMetadata {
    /// Device mapping metadata which can be used to construct a concrete device mapper.
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Flag to check if loading in ISQ.
    pub loading_isq: bool,
    /// Device mapping target device (the one that is not the CPU).
    pub real_device: Device,
    /// MultiProgress support for parallelized loading.
    pub multi_progress: Arc<MultiProgress>,
    /// Optional Matryoshka Transformer (Matformer) slicing configuration.
    pub matformer_slicing_config: Option<MatformerSliceConfig>,
}

/// Loader for a normal model: builds a [`NormalModel`] (optionally with X-LoRA
/// adapters) from a JSON `config` string and a [`ShardedVarBuilder`].
pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    #[allow(clippy::too_many_arguments)]
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}
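// Illustrative sketch (comments only, hypothetical tensor names): the default
// `get_device_for_tensor` closure keys the per-tensor load device off the
// layer index embedded in the tensor name, e.g.
//
//     "model.layers.17.self_attn.q_proj.weight" -> DeviceForLoadTensor::Idx(17)
//     "model.embed_tokens.weight"               -> DeviceForLoadTensor::Base
//
// The captured index is clamped to the model's layer count, so per-layer
// tensors land on the device chosen by the device mapper while everything
// else stays on the base device.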

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the normal model as.
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    #[serde(rename = "qwen3")]
    Qwen3,
    #[serde(rename = "glm4")]
    GLM4,
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
    #[serde(rename = "smollm3")]
    SmolLm3,
    #[serde(rename = "granitemoehybrid")]
    GraniteMoeHybrid,
}

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
impl NormalLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "MistralForCausalLM" => Ok(Self::Mistral),
            "MixtralForCausalLM" => Ok(Self::Mixtral),
            "GemmaForCausalLM" => Ok(Self::Gemma),
            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
            "PhiForCausalLM" => Ok(Self::Phi2),
            "Phi3ForCausalLM" => Ok(Self::Phi3),
            "LlamaForCausalLM" => Ok(Self::Llama),
            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
            "Glm4ForCausalLM" => Ok(Self::GLM4),
            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
            "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `*ForCausalLM` model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for NormalLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "mistral" => Ok(Self::Mistral),
            "gemma" => Ok(Self::Gemma),
            "mixtral" => Ok(Self::Mixtral),
            "llama" => Ok(Self::Llama),
            "phi2" => Ok(Self::Phi2),
            "phi3" => Ok(Self::Phi3),
            "qwen2" => Ok(Self::Qwen2),
            "gemma2" => Ok(Self::Gemma2),
            "starcoder2" => Ok(Self::Starcoder2),
            "phi3.5moe" => Ok(Self::Phi3_5MoE),
            "deepseekv2" => Ok(Self::DeepSeekV2),
            "deepseekv3" => Ok(Self::DeepSeekV3),
            "qwen3" => Ok(Self::Qwen3),
            "glm4" => Ok(Self::GLM4),
            "qwen3moe" => Ok(Self::Qwen3Moe),
            "smollm3" => Ok(Self::SmolLm3),
            "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`, `granitemoehybrid`.")),
        }
    }
}

impl Display for NormalLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gemma => write!(f, "gemma"),
            Self::Gemma2 => write!(f, "gemma2"),
            Self::Llama => write!(f, "llama"),
            Self::Mistral => write!(f, "mistral"),
            Self::Mixtral => write!(f, "mixtral"),
            Self::Phi2 => write!(f, "phi2"),
            Self::Phi3 => write!(f, "phi3"),
            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
            Self::Qwen2 => write!(f, "qwen2"),
            Self::Starcoder2 => write!(f, "starcoder2"),
            Self::DeepSeekV2 => write!(f, "deepseekv2"),
            Self::DeepSeekV3 => write!(f, "deepseekv3"),
            Self::Qwen3 => write!(f, "qwen3"),
            Self::GLM4 => write!(f, "glm4"),
            Self::Qwen3Moe => write!(f, "qwen3moe"),
            Self::SmolLm3 => write!(f, "smollm3"),
            Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
        }
    }
}
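// A minimal sketch of the `FromStr`/`Display` round trip above, assuming a
// unit-test module is acceptable in this file; the names used here are
// illustrative only.
#[cfg(test)]
mod normal_loader_type_roundtrip_tests {
    use super::NormalLoaderType;

    #[test]
    fn parse_and_display_are_inverse() {
        // "qwen3moe" parses to the Qwen3Moe variant and displays back to the same string.
        let tp: NormalLoaderType = "qwen3moe".parse().unwrap();
        assert_eq!(tp, NormalLoaderType::Qwen3Moe);
        assert_eq!(tp.to_string(), "qwen3moe");
        // Unknown names are rejected with a descriptive error listing the alternatives.
        assert!("not-a-model".parse::<NormalLoaderType>().is_err());
    }
}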

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
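// For example, `bias_if!(cfg.attention_bias, size_q)` contributes `size_q`
// elements when the config enables attention bias and 0 otherwise; the
// per-layer size estimates below use it to account for optional bias vectors.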

/// Load a model based on the Hugging Face Transformers `*ForCausalLM` model class
/// named in the `architectures` config field.
pub struct AutoNormalLoader;

#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    architectures: Vec<String>,
}

impl AutoNormalLoader {
    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one name in the `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = NormalLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
            NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
        }
    }
}
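// A typical Hugging Face `config.json` only needs the `architectures` field
// for this dispatch, e.g. (abridged, hypothetical values):
//
//     { "architectures": ["MistralForCausalLM"], "hidden_size": 4096, ... }
//
// `get_loader` then resolves `MistralForCausalLM` to `MistralLoader` and logs
// the detected loader type once.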

impl NormalModelLoader for AutoNormalLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load_xlora(
            config,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            normal_loading_metadata,
            preload_adapters,
        )
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }
    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.supports_paged_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoNormalLoader {
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
    }
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for AutoNormalLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

// ======================== Mistral loader

/// [`NormalLoader`] for a Mistral model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct MistralLoader;

impl NormalModelLoader for MistralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(models::mistral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(xlora_models::XLoraMistral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MistralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
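// Illustration of what these ISQ predicates match (hypothetical tensor names):
//
//     "model.layers.0.self_attn.q_proj.weight"  -> candidate for in-situ quantization
//     "model.layers.0.input_layernorm.weight"   -> left untouched
//
// i.e. only the attention/MLP projections and the LM head are quantized;
// norms and embeddings are skipped.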

impl DeviceMappedModelLoader for MistralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}
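// Worked example (hypothetical Mistral-7B-like config, weight_pack_factor = 1):
// hidden_size = 4096, num_attention_heads = 32, num_key_value_heads = 8,
// intermediate_size = 14336, dtype = F16.
//
//     q_proj / o_proj:            4096 * 4096      = 16_777_216 elements each
//     k_proj / v_proj:            4096 * (128 * 8) =  4_194_304 elements each
//     gate / up / down_proj:      4096 * 14336     = 58_720_256 elements each
//     two RMS norms:              2 * 4096
//
// Total ≈ 218.1M elements per layer, i.e. roughly 416 MiB per layer at
// 2 bytes/element, which is the per-layer figure `layer_sizes_in_bytes`
// would hand to automatic device mapping for such a config.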

// ======================== Gemma loader

/// [`NormalLoader`] for a Gemma model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct GemmaLoader;

impl NormalModelLoader for GemmaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::gemma::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraGemma::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for GemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Llama loader

/// [`NormalLoader`] for a Llama model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct LlamaLoader;

impl NormalModelLoader for LlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::llama::Llama::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraLlama::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for LlamaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for LlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Mixtral loader

/// [`NormalLoader`] for a Mixtral model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct MixtralLoader;

impl NormalModelLoader for MixtralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::mixtral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraMixtral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MixtralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // Experts
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MixtralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let moe_block = {
                let gate = cfg.hidden_size * cfg.num_local_experts;
                // Assume the expert weights are packed with the same weight pack factor
                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                gate + cfg.num_local_experts * w1
                    + cfg.num_local_experts * w2
                    + cfg.num_local_experts * w3
            };

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + moe_block
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Phi2 loader

/// [`NormalLoader`] for a Phi 2 model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Phi2Loader;

impl NormalModelLoader for Phi2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::phi2::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraPhi2::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Phi2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor + size_in;
            let (q_norm, k_norm) = if cfg.qk_layernorm {
                (cfg.head_dim(), cfg.head_dim())
            } else {
                (0, 0)
            };

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let fc1 = h_size * i_size / weight_pack_factor;
            let fc2 = h_size * i_size / weight_pack_factor;

            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}
1397
1398// ======================== Phi3 loader
1399
1400/// [`NormalLoader`] for a Phi 3 model.
1401///
1402/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1403pub struct Phi3Loader;
1404
1405impl NormalModelLoader for Phi3Loader {
1406    fn load(
1407        &self,
1408        config: &str,
1409        vb: ShardedVarBuilder,
1410        normal_loading_metadata: NormalLoadingMetadata,
1411        attention_mechanism: AttentionImplementation,
1412    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1413        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1414
1415        Ok(Box::new(models::phi3::Model::new(
1416            &cfg,
1417            vb,
1418            self.is_gptx(config)?,
1419            normal_loading_metadata,
1420            attention_mechanism,
1421        )?))
1422    }
1423    fn load_xlora(
1424        &self,
1425        config: &str,
1426        vb: ShardedVarBuilder,
1427        lora_config: &[((String, String), LoraConfig)],
1428        xlora_config: Option<XLoraConfig>,
1429        xlora_ordering: Ordering,
1430        normal_loading_metadata: NormalLoadingMetadata,
1431        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1432    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1433        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1434
1435        Ok(Box::new(xlora_models::XLoraPhi3::new(
1436            &cfg,
1437            vb,
1438            lora_config,
1439            xlora_config,
1440            xlora_ordering,
1441            self.is_gptx(config)?,
1442            normal_loading_metadata,
1443            preload_adapters,
1444        )?))
1445    }
1446    fn is_gptx(&self, _: &str) -> Result<bool> {
1447        Ok(true)
1448    }
1449    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1450        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1451
1452        Ok(Box::new(cfg))
1453    }
1454}
1455
1456impl IsqModelLoader for Phi3Loader {
1457    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1458        Ok(vec![
1459            Regex::new(r"lm_head\.(weight|bias)$")?,
1460            // Attention
1461            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1462            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1463            // MLP
1464            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1465            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1466            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1467        ])
1468    }
1469    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1470        self.isq_layer_regexes(config)
1471    }
1472}
1473
1474impl DeviceMappedModelLoader for Phi3Loader {
1475    fn mapped_max_act_size_elems(
1476        &self,
1477        config: &str,
1478        params: &AutoDeviceMapParams,
1479    ) -> Result<usize> {
1480        let AutoDeviceMapParams::Text {
1481            max_seq_len,
1482            max_batch_size,
1483        } = params
1484        else {
1485            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1486        };
1487
1488        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1489
1490        Ok(
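        // Peak mapped activation: a single attention-score tensor of shape
        // [batch, heads, chunk, chunk], with the sequence length capped at ATTENTION_CHUNK_SIZE.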
1491            max_batch_size
1492                * cfg.num_attention_heads
1493                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1494        )
1495    }
1496    fn non_mapped_max_act_size_elems(
1497        &self,
1498        _config: &str,
1499        _params: &AutoDeviceMapParams,
1500    ) -> Result<usize> {
1501        Ok(0)
1502    }
1503
1504    fn non_mapped_size_in_bytes(
1505        &self,
1506        config: &str,
1507        dtype: DType,
1508        weight_pack_factor: usize,
1509        _matformer_config: Option<&MatformerSliceConfig>,
1510    ) -> Result<usize> {
1511        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1512
1513        let elems = {
1514            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1515            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1516            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1517                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1518            } else {
1519                0
1520            };
1521            let norm = cfg.hidden_size;
1522            embed_tokens + lm_head + norm
1523        };
1524        Ok(elems * dtype.size_in_bytes())
1525    }
1526
1527    fn layer_sizes_in_bytes(
1528        &self,
1529        config: &str,
1530        dtype: DType,
1531        weight_pack_factor: usize,
1532        _matformer_config: Option<&MatformerSliceConfig>,
1533    ) -> Result<Vec<usize>> {
1534        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1535
1536        let per_layer_elems = {
1537            let input_layernorm = cfg.hidden_size;
1538            let post_attention_layernorm = cfg.hidden_size;
1539
1540            let size_in = cfg.hidden_size;
1541            let head_dim = cfg.head_dim();
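            // Phi3 fuses Q, K and V into one qkv_proj: Q contributes num_attention_heads * head_dim
            // output features, while K and V each contribute num_key_value_heads * head_dim.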
1542            let op_size =
1543                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1544            let qkv_proj = size_in * op_size / weight_pack_factor;
1545            let o_proj =
1546                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1547
1548            let h_size = cfg.hidden_size;
1549            let i_size = cfg.intermediate_size;
1550            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1551            let down_proj = h_size * i_size / weight_pack_factor;
1552
1553            input_layernorm
1554                + post_attention_layernorm
1555                + qkv_proj
1556                + o_proj
1557                + gate_up_proj
1558                + down_proj
1559        };
1560        Ok(vec![
1561            per_layer_elems * dtype.size_in_bytes();
1562            cfg.num_hidden_layers
1563        ])
1564    }
1565
1566    fn num_layers(&self, config: &str) -> Result<usize> {
1567        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1568
1569        Ok(cfg.num_hidden_layers)
1570    }
1571
1572    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1573        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1574
1575        let cfg = ModelConfigMetadata {
1576            max_seq_len: cfg.max_position_embeddings,
1577            num_layers: cfg.num_hidden_layers,
1578            hidden_size: cfg.hidden_size,
1579            num_kv_heads: cfg.num_key_value_heads,
1580            num_attn_heads: cfg.num_attention_heads,
1581            sliding_window: cfg.sliding_window,
1582            k_head_dim: cfg.head_dim(),
1583            v_head_dim: cfg.head_dim(),
1584        };
1585
1586        Ok(Box::new(cfg))
1587    }
1588}
1589
1590// ======================== Qwen2 loader
1591
1592/// [`NormalLoader`] for a Qwen 2 model.
1593///
1594/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1595pub struct Qwen2Loader;
1596
1597impl NormalModelLoader for Qwen2Loader {
1598    fn load(
1599        &self,
1600        config: &str,
1601        vb: ShardedVarBuilder,
1602        normal_loading_metadata: NormalLoadingMetadata,
1603        attention_mechanism: AttentionImplementation,
1604    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1605        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1606
1607        Ok(Box::new(models::qwen2::Model::new(
1608            &cfg,
1609            vb,
1610            self.is_gptx(config)?,
1611            normal_loading_metadata,
1612            attention_mechanism,
1613        )?))
1614    }
1615    fn load_xlora(
1616        &self,
1617        _config: &str,
1618        _vb: ShardedVarBuilder,
1619        _lora_config: &[((String, String), LoraConfig)],
1620        _xlora_config: Option<XLoraConfig>,
1621        _xlora_ordering: Ordering,
1622        _normal_loading_metadata: NormalLoadingMetadata,
1623        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1624    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
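        // X-LoRA is not currently implemented for Qwen 2; this path panics via `todo!()`.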
1625        todo!()
1626    }
1627    fn is_gptx(&self, _: &str) -> Result<bool> {
1628        Ok(true)
1629    }
1630    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1631        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1632
1633        Ok(Box::new(cfg))
1634    }
1635}
1636
1637impl IsqModelLoader for Qwen2Loader {
1638    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1639        Ok(vec![
1640            Regex::new(r"lm_head\.(weight|bias)$")?,
1641            // Attention
1642            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1643            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1644            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1645            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1646            // MLP
1647            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1648            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1649            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1650        ])
1651    }
1652    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1653        self.isq_layer_regexes(config)
1654    }
1655}
1656
1657impl DeviceMappedModelLoader for Qwen2Loader {
1658    fn mapped_max_act_size_elems(
1659        &self,
1660        config: &str,
1661        params: &AutoDeviceMapParams,
1662    ) -> Result<usize> {
1663        let AutoDeviceMapParams::Text {
1664            max_seq_len,
1665            max_batch_size,
1666        } = params
1667        else {
1668            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1669        };
1670
1671        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1672
1673        Ok(
1674            max_batch_size
1675                * cfg.num_attention_heads
1676                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1677        )
1678    }
1679    fn non_mapped_max_act_size_elems(
1680        &self,
1681        _config: &str,
1682        _params: &AutoDeviceMapParams,
1683    ) -> Result<usize> {
1684        Ok(0)
1685    }
1686
1687    fn non_mapped_size_in_bytes(
1688        &self,
1689        config: &str,
1690        dtype: DType,
1691        weight_pack_factor: usize,
1692        _matformer_config: Option<&MatformerSliceConfig>,
1693    ) -> Result<usize> {
1694        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1695
1696        let elems = {
1697            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1698            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1699            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1700                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1701            } else {
1702                0
1703            };
1704            let norm = cfg.hidden_size;
1705            embed_tokens + lm_head + norm
1706        };
1707        Ok(elems * dtype.size_in_bytes())
1708    }
1709
1710    fn layer_sizes_in_bytes(
1711        &self,
1712        config: &str,
1713        dtype: DType,
1714        weight_pack_factor: usize,
1715        _matformer_config: Option<&MatformerSliceConfig>,
1716    ) -> Result<Vec<usize>> {
1717        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1718
1719        let per_layer_elems = {
1720            let input_layernorm = cfg.hidden_size;
1721            let post_attention_layernorm = cfg.hidden_size;
1722
1723            let size_in = cfg.hidden_size;
1724            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1725            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
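            // Qwen2 uses biases on the q/k/v projections (the trailing `+ size_*` terms) but not on o_proj.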
1726            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1727            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1728            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1729            let o_proj = size_q * size_in / weight_pack_factor;
1730
1731            let h_size = cfg.hidden_size;
1732            let i_size = cfg.intermediate_size;
1733            let gate_proj = h_size * i_size / weight_pack_factor;
1734            let up_proj = h_size * i_size / weight_pack_factor;
1735            let down_proj = i_size * h_size / weight_pack_factor;
1736
1737            input_layernorm
1738                + post_attention_layernorm
1739                + q_proj
1740                + k_proj
1741                + v_proj
1742                + o_proj
1743                + gate_proj
1744                + up_proj
1745                + down_proj
1746        };
1747        Ok(vec![
1748            per_layer_elems * dtype.size_in_bytes();
1749            cfg.num_hidden_layers
1750        ])
1751    }
1752
1753    fn num_layers(&self, config: &str) -> Result<usize> {
1754        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1755
1756        Ok(cfg.num_hidden_layers)
1757    }
1758
1759    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1760        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1761
1762        let cfg = ModelConfigMetadata {
1763            max_seq_len: cfg.max_position_embeddings,
1764            num_layers: cfg.num_hidden_layers,
1765            hidden_size: cfg.hidden_size,
1766            num_kv_heads: cfg.num_key_value_heads,
1767            num_attn_heads: cfg.num_attention_heads,
1768            sliding_window: cfg.sliding_window,
1769            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1770            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1771        };
1772
1773        Ok(Box::new(cfg))
1774    }
1775}
1776
1777// ======================== Gemma2 loader
1778
1779/// [`NormalLoader`] for a Gemma2 model.
1780///
1781/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1782pub struct Gemma2Loader;
1783
1784impl NormalModelLoader for Gemma2Loader {
1785    fn load(
1786        &self,
1787        config: &str,
1788        vb: ShardedVarBuilder,
1789        normal_loading_metadata: NormalLoadingMetadata,
1790        attention_mechanism: AttentionImplementation,
1791    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1792        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1793
1794        Ok(Box::new(models::gemma2::Model::new(
1795            &cfg,
1796            vb,
1797            self.is_gptx(config)?,
1798            normal_loading_metadata,
1799            attention_mechanism,
1800        )?))
1801    }
1802    fn load_xlora(
1803        &self,
1804        config: &str,
1805        vb: ShardedVarBuilder,
1806        lora_config: &[((String, String), LoraConfig)],
1807        xlora_config: Option<XLoraConfig>,
1808        xlora_ordering: Ordering,
1809        normal_loading_metadata: NormalLoadingMetadata,
1810        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1811    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1812        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1813
1814        Ok(Box::new(xlora_models::XLoraGemma2::new(
1815            &cfg,
1816            vb,
1817            lora_config,
1818            xlora_config,
1819            xlora_ordering,
1820            self.is_gptx(config)?,
1821            normal_loading_metadata,
1822            preload_adapters,
1823        )?))
1824    }
1825    fn is_gptx(&self, _: &str) -> Result<bool> {
1826        Ok(true)
1827    }
1828    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1829        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1830
1831        Ok(Box::new(cfg))
1832    }
1833}
1834
1835impl IsqModelLoader for Gemma2Loader {
1836    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1837        Ok(vec![
1838            Regex::new(r"lm_head\.(weight|bias)$")?,
1839            // Attention
1840            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1841            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1842            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1843            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1844            // MLP
1845            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1846            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1847            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1848        ])
1849    }
1850    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1851        self.isq_layer_regexes(config)
1852    }
1853}
1854
1855impl DeviceMappedModelLoader for Gemma2Loader {
1856    fn mapped_max_act_size_elems(
1857        &self,
1858        config: &str,
1859        params: &AutoDeviceMapParams,
1860    ) -> Result<usize> {
1861        let AutoDeviceMapParams::Text {
1862            max_seq_len,
1863            max_batch_size,
1864        } = params
1865        else {
1866            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1867        };
1868
1869        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1870
1871        Ok(
1872            max_batch_size
1873                * cfg.num_attention_heads
1874                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1875        )
1876    }
1877    fn non_mapped_max_act_size_elems(
1878        &self,
1879        _config: &str,
1880        _params: &AutoDeviceMapParams,
1881    ) -> Result<usize> {
1882        Ok(0)
1883    }
1884
1885    fn non_mapped_size_in_bytes(
1886        &self,
1887        config: &str,
1888        dtype: DType,
1889        weight_pack_factor: usize,
1890        _matformer_config: Option<&MatformerSliceConfig>,
1891    ) -> Result<usize> {
1892        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1893
1894        let elems = {
1895            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1896            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1897            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1898                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1899            } else {
1900                0
1901            };
1902            let norm = cfg.hidden_size;
1903            embed_tokens + lm_head + norm
1904        };
1905        Ok(elems * dtype.size_in_bytes())
1906    }
1907
1908    fn layer_sizes_in_bytes(
1909        &self,
1910        config: &str,
1911        dtype: DType,
1912        weight_pack_factor: usize,
1913        _matformer_config: Option<&MatformerSliceConfig>,
1914    ) -> Result<Vec<usize>> {
1915        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1916
1917        let per_layer_elems = {
1918            let input_layernorm = cfg.hidden_size;
1919            let post_attention_layernorm = cfg.hidden_size;
1920
1921            let size_in = cfg.hidden_size;
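            // Gemma2 defines head_dim explicitly in its config; it need not equal hidden_size / num_attention_heads.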
1922            let size_q = cfg.head_dim * cfg.num_attention_heads;
1923            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1924            let q_proj =
1925                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1926            let k_proj =
1927                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1928            let v_proj =
1929                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1930            let o_proj =
1931                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1932
1933            let h_size = cfg.hidden_size;
1934            let i_size = cfg.intermediate_size;
1935            let gate_proj = h_size * i_size / weight_pack_factor;
1936            let up_proj = h_size * i_size / weight_pack_factor;
1937            let down_proj = i_size * h_size / weight_pack_factor;
1938
1939            input_layernorm
1940                + post_attention_layernorm
1941                + q_proj
1942                + k_proj
1943                + v_proj
1944                + o_proj
1945                + gate_proj
1946                + up_proj
1947                + down_proj
1948        };
1949        Ok(vec![
1950            per_layer_elems * dtype.size_in_bytes();
1951            cfg.num_hidden_layers
1952        ])
1953    }
1954
1955    fn num_layers(&self, config: &str) -> Result<usize> {
1956        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1957
1958        Ok(cfg.num_hidden_layers)
1959    }
1960    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1961        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1962
1963        let cfg = ModelConfigMetadata {
1964            max_seq_len: cfg.max_position_embeddings,
1965            num_layers: cfg.num_hidden_layers,
1966            hidden_size: cfg.hidden_size,
1967            num_kv_heads: cfg.num_key_value_heads,
1968            num_attn_heads: cfg.num_attention_heads,
1969            sliding_window: None, // Use None to be more forgiving: some Gemma2 configs do not set a sliding window
1970            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1971            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1972        };
1973
1974        Ok(Box::new(cfg))
1975    }
1976}
1977
1978// ======================== Starcoder2 loader
1979
1980/// [`NormalLoader`] for a Starcoder2 model.
1981///
1982/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1983pub struct Starcoder2Loader;
1984
1985impl NormalModelLoader for Starcoder2Loader {
1986    fn load(
1987        &self,
1988        config: &str,
1989        vb: ShardedVarBuilder,
1990        normal_loading_metadata: NormalLoadingMetadata,
1991        attention_mechanism: AttentionImplementation,
1992    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1993        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1994
1995        Ok(Box::new(models::starcoder2::Model::new(
1996            &cfg,
1997            vb,
1998            self.is_gptx(config)?,
1999            normal_loading_metadata,
2000            attention_mechanism,
2001        )?))
2002    }
2003    fn load_xlora(
2004        &self,
2005        config: &str,
2006        vb: ShardedVarBuilder,
2007        lora_config: &[((String, String), LoraConfig)],
2008        xlora_config: Option<XLoraConfig>,
2009        xlora_ordering: Ordering,
2010        normal_loading_metadata: NormalLoadingMetadata,
2011        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2012    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2013        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2014
2015        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2016            &cfg,
2017            vb,
2018            lora_config,
2019            xlora_config,
2020            xlora_ordering,
2021            self.is_gptx(config)?,
2022            normal_loading_metadata,
2023            preload_adapters,
2024        )?))
2025    }
2026    fn is_gptx(&self, _: &str) -> Result<bool> {
2027        Ok(true)
2028    }
2029    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2030        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2031
2032        Ok(Box::new(cfg))
2033    }
2034}
2035
2036impl IsqModelLoader for Starcoder2Loader {
2037    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2038        Ok(vec![
2039            Regex::new(r"lm_head\.(weight|bias)$")?,
2040            // Attention
2041            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2042            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2043            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2044            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2045            // MLP
2046            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2047            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2048        ])
2049    }
2050    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2051        self.isq_layer_regexes(config)
2052    }
2053}
2054
2055impl DeviceMappedModelLoader for Starcoder2Loader {
2056    fn mapped_max_act_size_elems(
2057        &self,
2058        config: &str,
2059        params: &AutoDeviceMapParams,
2060    ) -> Result<usize> {
2061        let AutoDeviceMapParams::Text {
2062            max_seq_len,
2063            max_batch_size,
2064        } = params
2065        else {
2066            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2067        };
2068
2069        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2070
2071        Ok(
2072            max_batch_size
2073                * cfg.num_attention_heads
2074                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2075        )
2076    }
2077    fn non_mapped_max_act_size_elems(
2078        &self,
2079        _config: &str,
2080        _params: &AutoDeviceMapParams,
2081    ) -> Result<usize> {
2082        Ok(0)
2083    }
2084
2085    fn non_mapped_size_in_bytes(
2086        &self,
2087        config: &str,
2088        dtype: DType,
2089        weight_pack_factor: usize,
2090        _matformer_config: Option<&MatformerSliceConfig>,
2091    ) -> Result<usize> {
2092        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2093
2094        let elems = {
2095            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2096            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2097            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2098                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2099            } else {
2100                0
2101            };
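            // The final LayerNorm is counted as weight + bias, hence 2 * hidden_size elements.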
2102            let norm = cfg.hidden_size + cfg.hidden_size;
2103            embed_tokens + lm_head + norm
2104        };
2105        Ok(elems * dtype.size_in_bytes())
2106    }
2107
2108    fn layer_sizes_in_bytes(
2109        &self,
2110        config: &str,
2111        dtype: DType,
2112        weight_pack_factor: usize,
2113        _matformer_config: Option<&MatformerSliceConfig>,
2114    ) -> Result<Vec<usize>> {
2115        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2116
2117        let per_layer_elems = {
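            // Starcoder2 LayerNorms carry weight + bias, so each is counted as 2 * hidden_size.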
2118            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2119            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2120
2121            let size_in = cfg.hidden_size;
2122            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2123            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2124            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2125            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2126            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2127            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2128
2129            let h_size = cfg.hidden_size;
2130            let i_size = cfg.intermediate_size;
2131            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2132            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2133
2134            input_layernorm
2135                + post_attention_layernorm
2136                + q_proj
2137                + k_proj
2138                + v_proj
2139                + o_proj
2140                + fc1
2141                + fc2
2142        };
2143        Ok(vec![
2144            per_layer_elems * dtype.size_in_bytes();
2145            cfg.num_hidden_layers
2146        ])
2147    }
2148
2149    fn num_layers(&self, config: &str) -> Result<usize> {
2150        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2151
2152        Ok(cfg.num_hidden_layers)
2153    }
2154
2155    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2156        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2157
2158        let cfg = ModelConfigMetadata {
2159            max_seq_len: cfg.max_position_embeddings,
2160            num_layers: cfg.num_hidden_layers,
2161            hidden_size: cfg.hidden_size,
2162            num_kv_heads: cfg.num_key_value_heads,
2163            num_attn_heads: cfg.num_attention_heads,
2164            sliding_window: cfg.sliding_window,
2165            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2166            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2167        };
2168
2169        Ok(Box::new(cfg))
2170    }
2171}
2172
2173// ======================== Phi3.5 MoE loader
2174
2175/// [`NormalLoader`] for a Phi 3.5 MoE model.
2176///
2177/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2178pub struct Phi3_5MoELoader;
2179
2180impl NormalModelLoader for Phi3_5MoELoader {
2181    fn load(
2182        &self,
2183        config: &str,
2184        vb: ShardedVarBuilder,
2185        normal_loading_metadata: NormalLoadingMetadata,
2186        attention_mechanism: AttentionImplementation,
2187    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2188        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2189
2190        Ok(Box::new(models::phi3_5_moe::Model::new(
2191            &cfg,
2192            vb,
2193            self.is_gptx(config)?,
2194            normal_loading_metadata,
2195            attention_mechanism,
2196        )?))
2197    }
2198    fn load_xlora(
2199        &self,
2200        config: &str,
2201        vb: ShardedVarBuilder,
2202        lora_config: &[((String, String), LoraConfig)],
2203        xlora_config: Option<XLoraConfig>,
2204        xlora_ordering: Ordering,
2205        normal_loading_metadata: NormalLoadingMetadata,
2206        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2207    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2208        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2209
2210        Ok(Box::new(xlora_models::XLoraPhi3::new(
2211            &cfg,
2212            vb,
2213            lora_config,
2214            xlora_config,
2215            xlora_ordering,
2216            self.is_gptx(config)?,
2217            normal_loading_metadata,
2218            preload_adapters,
2219        )?))
2220    }
2221    fn is_gptx(&self, _: &str) -> Result<bool> {
2222        Ok(true)
2223    }
2224    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2225        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2226
2227        Ok(Box::new(cfg))
2228    }
2229}
2230
2231impl IsqModelLoader for Phi3_5MoELoader {
2232    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2233        Ok(vec![
2234            Regex::new(r"lm_head\.(weight|bias)$")?,
2235            // Attention
2236            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2237            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2238            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2239            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2240            // MLP
2241            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2242            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2243            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2244        ])
2245    }
2246    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2247        self.isq_layer_regexes(config)
2248    }
2249
2250    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2251        Ok(vec![
2252            Regex::new(r"lm_head\.(weight|bias)$")?,
2253            // MLP
2254            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2255            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2256            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2257        ])
2258    }
2259    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2260        self.isq_layer_regexes_moqe(config)
2261    }
2262}
2263
2264impl DeviceMappedModelLoader for Phi3_5MoELoader {
2265    fn mapped_max_act_size_elems(
2266        &self,
2267        config: &str,
2268        params: &AutoDeviceMapParams,
2269    ) -> Result<usize> {
2270        let AutoDeviceMapParams::Text {
2271            max_seq_len,
2272            max_batch_size,
2273        } = params
2274        else {
2275            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2276        };
2277
2278        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2279
2280        Ok(
2281            max_batch_size
2282                * cfg.num_attention_heads
2283                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2284        )
2285    }
2286    fn non_mapped_max_act_size_elems(
2287        &self,
2288        _config: &str,
2289        _params: &AutoDeviceMapParams,
2290    ) -> Result<usize> {
2291        Ok(0)
2292    }
2293
2294    fn non_mapped_size_in_bytes(
2295        &self,
2296        config: &str,
2297        dtype: DType,
2298        weight_pack_factor: usize,
2299        _matformer_config: Option<&MatformerSliceConfig>,
2300    ) -> Result<usize> {
2301        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2302
2303        let elems = {
2304            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2305            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2306            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2307                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2308            } else {
2309                0
2310            };
2311            let norm = cfg.hidden_size;
2312            embed_tokens + lm_head + norm
2313        };
2314        Ok(elems * dtype.size_in_bytes())
2315    }
2316
2317    fn layer_sizes_in_bytes(
2318        &self,
2319        config: &str,
2320        dtype: DType,
2321        weight_pack_factor: usize,
2322        _matformer_config: Option<&MatformerSliceConfig>,
2323    ) -> Result<Vec<usize>> {
2324        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2325
2326        let per_layer_elems = {
2327            let input_layernorm = cfg.hidden_size;
2328            let post_attention_layernorm = cfg.hidden_size;
2329
2330            let size_in = cfg.hidden_size;
2331            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2332            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2333            let q_proj =
2334                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2335            let k_proj =
2336                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2337            let v_proj =
2338                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2339            let o_proj =
2340                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2341
2342            let moe_block = {
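                // `gate` is the router (hidden_size weights per expert); each expert is a SwiGLU MLP
                // with w1/w3 as gate/up projections and w2 as the down projection.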
2343                let gate = cfg.hidden_size * cfg.num_local_experts;
2344                // Assume the quantization weight pack factor applies to each expert's weights
2345                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2346                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2347                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2348                gate + cfg.num_local_experts * w1
2349                    + cfg.num_local_experts * w2
2350                    + cfg.num_local_experts * w3
2351            };
2352
2353            input_layernorm
2354                + post_attention_layernorm
2355                + q_proj
2356                + k_proj
2357                + v_proj
2358                + o_proj
2359                + moe_block
2360        };
2361        Ok(vec![
2362            per_layer_elems * dtype.size_in_bytes();
2363            cfg.num_hidden_layers
2364        ])
2365    }
2366
2367    fn num_layers(&self, config: &str) -> Result<usize> {
2368        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2369
2370        Ok(cfg.num_hidden_layers)
2371    }
2372
2373    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2374        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2375
2376        let cfg = ModelConfigMetadata {
2377            max_seq_len: cfg.max_position_embeddings,
2378            num_layers: cfg.num_hidden_layers,
2379            hidden_size: cfg.hidden_size,
2380            num_kv_heads: cfg.num_key_value_heads,
2381            num_attn_heads: cfg.num_attention_heads,
2382            sliding_window: cfg.sliding_window,
2383            k_head_dim: cfg.head_dim(),
2384            v_head_dim: cfg.head_dim(),
2385        };
2386
2387        Ok(Box::new(cfg))
2388    }
2389}
2390
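// ======================== DeepSeekV2 loader
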
2391/// [`NormalLoader`] for a DeepSeekV2 model.
2392///
2393/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2394pub struct DeepSeekV2Loader;
2395
2396impl NormalModelLoader for DeepSeekV2Loader {
2397    fn load(
2398        &self,
2399        config: &str,
2400        vb: ShardedVarBuilder,
2401        normal_loading_metadata: NormalLoadingMetadata,
2402        attention_mechanism: AttentionImplementation,
2403    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2404        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2405
2406        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2407            &cfg,
2408            vb,
2409            self.is_gptx(config)?,
2410            normal_loading_metadata,
2411            attention_mechanism,
2412        )?))
2413    }
2414    fn load_xlora(
2415        &self,
2416        _config: &str,
2417        _vb: ShardedVarBuilder,
2418        _lora_config: &[((String, String), LoraConfig)],
2419        _xlora_config: Option<XLoraConfig>,
2420        _xlora_ordering: Ordering,
2421        _normal_loading_metadata: NormalLoadingMetadata,
2422        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2423    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2424        todo!()
2425    }
2426    fn is_gptx(&self, _: &str) -> Result<bool> {
2427        Ok(true)
2428    }
2429    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2430        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2431        Ok(Box::new(cfg))
2432    }
2433}
2434
2435impl IsqModelLoader for DeepSeekV2Loader {
2436    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2437        let mut data = vec![
2438            Regex::new(r"lm_head\.(weight|bias)$")?,
2439            // Attention
2440            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2441            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2442            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2443        ];
2444        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2445        if cfg.q_lora_rank.is_some() {
2446            data.extend(vec![
2447                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2448                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2449            ]);
2450        } else {
2451            data.push(Regex::new(
2452                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2453            )?);
2454        }
2455        for layer_idx in 0..cfg.num_hidden_layers {
2456            if cfg.n_routed_experts.is_some()
2457                && layer_idx >= cfg.first_k_dense_replace
2458                && layer_idx % cfg.moe_layer_freq == 0
2459            {
2460                for i in 0..cfg.n_routed_experts.unwrap() {
2461                    data.extend(vec![
2462                        Regex::new(&format!(
2463                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2464                        ))?,
2465                        Regex::new(&format!(
2466                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2467                        ))?,
2468                        Regex::new(&format!(
2469                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2470                        ))?,
2471                    ]);
2472                }
2473                if cfg.n_shared_experts.is_some() {
2474                    data.extend(vec![
2475                        Regex::new(&format!(
2476                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2477                        ))?,
2478                        Regex::new(&format!(
2479                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2480                        ))?,
2481                        Regex::new(&format!(
2482                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2483                        ))?,
2484                    ]);
2485                }
2486            } else {
2487                data.extend(vec![
2488                    Regex::new(&format!(
2489                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2490                    ))?,
2491                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2492                    Regex::new(&format!(
2493                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2494                    ))?,
2495                ]);
2496            };
2497        }
2498        Ok(data)
2499    }
2500    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2501        self.isq_layer_regexes(config)
2502    }
2503
2504    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2505        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2506        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2507        for layer_idx in 0..cfg.num_hidden_layers {
2508            if cfg.n_routed_experts.is_some()
2509                && layer_idx >= cfg.first_k_dense_replace
2510                && layer_idx % cfg.moe_layer_freq == 0
2511            {
2512                for i in 0..cfg.n_routed_experts.unwrap() {
2513                    data.extend(vec![
2514                        Regex::new(&format!(
2515                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2516                        ))?,
2517                        Regex::new(&format!(
2518                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2519                        ))?,
2520                        Regex::new(&format!(
2521                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2522                        ))?,
2523                    ]);
2524                }
2525                if cfg.n_shared_experts.is_some() {
2526                    data.extend(vec![
2527                        Regex::new(&format!(
2528                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2529                        ))?,
2530                        Regex::new(&format!(
2531                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2532                        ))?,
2533                        Regex::new(&format!(
2534                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2535                        ))?,
2536                    ]);
2537                }
2538            } else {
2539                data.extend(vec![
2540                    Regex::new(&format!(
2541                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2542                    ))?,
2543                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2544                    Regex::new(&format!(
2545                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2546                    ))?,
2547                ]);
2548            };
2549        }
2550        Ok(data)
2551    }
2552    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2553        self.isq_layer_regexes_moqe(config)
2554    }
2555}
2556
2557impl DeviceMappedModelLoader for DeepSeekV2Loader {
2558    fn mapped_max_act_size_elems(
2559        &self,
2560        config: &str,
2561        params: &AutoDeviceMapParams,
2562    ) -> Result<usize> {
2563        let AutoDeviceMapParams::Text {
2564            max_seq_len,
2565            max_batch_size,
2566        } = params
2567        else {
2568            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2569        };
2570
2571        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2572
2573        Ok(
2574            max_batch_size
2575                * cfg.num_attention_heads
2576                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2577        )
2578    }
2579    fn non_mapped_max_act_size_elems(
2580        &self,
2581        _config: &str,
2582        _params: &AutoDeviceMapParams,
2583    ) -> Result<usize> {
2584        Ok(0)
2585    }
2586
2587    fn non_mapped_size_in_bytes(
2588        &self,
2589        config: &str,
2590        dtype: DType,
2591        weight_pack_factor: usize,
2592        _matformer_config: Option<&MatformerSliceConfig>,
2593    ) -> Result<usize> {
2594        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2595        let elems = {
2596            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2597            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2598            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2599                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2600            } else {
2601                0
2602            };
2603            let norm = cfg.hidden_size;
2604            embed_tokens + lm_head + norm
2605        };
2606        Ok(elems * dtype.size_in_bytes())
2607    }
2608
2609    fn layer_sizes_in_bytes(
2610        &self,
2611        config: &str,
2612        dtype: DType,
2613        weight_pack_factor: usize,
2614        _matformer_config: Option<&MatformerSliceConfig>,
2615    ) -> Result<Vec<usize>> {
2616        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2617        let mut per_layer_elems = Vec::new();
2618
2619        for layer_idx in 0..cfg.num_hidden_layers {
2620            let input_layernorm = cfg.hidden_size;
2621            let post_attention_layernorm = cfg.hidden_size;
2622
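            // MLA query path: either a dense q_proj, or the low-rank pair q_a_proj (hidden_size x q_lora_rank)
            // with q_a_layernorm plus q_b_proj (q_lora_rank x num_attention_heads * q_head_dim).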
2623            let q_proj = match cfg.q_lora_rank {
2624                Some(lora_rank) => {
2625                    let a = cfg.hidden_size * lora_rank;
2626                    let norm = lora_rank;
2627                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2628                    a + norm + b
2629                }
2630                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2631            };
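            // Compressed KV path: kv_a_proj_with_mqa maps hidden states to kv_lora_rank + qk_rope_head_dim,
            // and kv_b_proj expands the rank to per-head (nope + value) dimensions.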
2632            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2633                / weight_pack_factor
2634                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2635            let kv_a_layernorm = cfg.kv_lora_rank;
2636            let kv_b_proj = cfg.kv_lora_rank
2637                * cfg.num_attention_heads
2638                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2639                / weight_pack_factor;
2640            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2641                / weight_pack_factor
2642                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2643
2644            let moe_block = {
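                // MoE replaces the dense MLP once layer_idx >= first_k_dense_replace, on every
                // moe_layer_freq-th layer; earlier layers keep a dense gate/up/down MLP.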
2645                let mut sum = 0;
2646                if cfg.n_routed_experts.is_some()
2647                    && layer_idx >= cfg.first_k_dense_replace
2648                    && layer_idx % cfg.moe_layer_freq == 0
2649                {
2650                    let h_size = cfg.hidden_size;
2651                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2652                        * cfg.n_routed_experts.unwrap();
2653                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2654                        * cfg.n_routed_experts.unwrap();
2655                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2656                        * cfg.n_routed_experts.unwrap();
2657                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2658                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2659                            / weight_pack_factor;
2660                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2661                            / weight_pack_factor;
2662                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2663                            / weight_pack_factor;
2664                        gate_proj + up_proj + down_proj
2665                    } else {
2666                        0
2667                    };
2668                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2669                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2670                } else {
2671                    let h_size = cfg.hidden_size;
2672                    let i_size = cfg.intermediate_size;
2673                    let gate_proj = h_size * i_size / weight_pack_factor;
2674                    let up_proj = h_size * i_size / weight_pack_factor;
2675                    let down_proj = i_size * h_size / weight_pack_factor;
2676                    sum += gate_proj + up_proj + down_proj;
2677                }
2678                sum
2679            };
2680
2681            per_layer_elems.push(
2682                input_layernorm
2683                    + post_attention_layernorm
2684                    + q_proj
2685                    + kv_a_layernorm
2686                    + kv_a_proj_with_mqa
2687                    + kv_b_proj
2688                    + o_proj
2689                    + moe_block,
2690            );
2691        }
2692
2693        Ok(per_layer_elems
2694            .into_iter()
2695            .map(|x| x * dtype.size_in_bytes())
2696            .collect())
2697    }
2698
2699    fn num_layers(&self, config: &str) -> Result<usize> {
2700        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2701        Ok(cfg.num_hidden_layers)
2702    }
2703
2704    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2705        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2706
2707        let cfg = ModelConfigMetadata {
2708            max_seq_len: cfg.max_position_embeddings,
2709            num_layers: cfg.num_hidden_layers,
2710            hidden_size: cfg.hidden_size,
2711            num_kv_heads: cfg.num_attention_heads,
2712            num_attn_heads: cfg.num_attention_heads,
2713            sliding_window: None,
2714            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2715            v_head_dim: cfg.v_head_dim,
2716        };
2717
2718        Ok(Box::new(cfg))
2719    }
2720}
2721
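// ======================== DeepSeekV3 loader
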
2722/// [`NormalLoader`] for a DeepSeekV3 model.
2723///
2724/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2725pub struct DeepSeekV3Loader;
2726
2727impl NormalModelLoader for DeepSeekV3Loader {
2728    fn load(
2729        &self,
2730        config: &str,
2731        vb: ShardedVarBuilder,
2732        normal_loading_metadata: NormalLoadingMetadata,
2733        attention_mechanism: AttentionImplementation,
2734    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2735        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2736        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2737            &cfg,
2738            vb,
2739            self.is_gptx(config)?,
2740            normal_loading_metadata,
2741            attention_mechanism,
2742        )?))
2743    }
2744    fn load_xlora(
2745        &self,
2746        _config: &str,
2747        _vb: ShardedVarBuilder,
2748        _lora_config: &[((String, String), LoraConfig)],
2749        _xlora_config: Option<XLoraConfig>,
2750        _xlora_ordering: Ordering,
2751        _normal_loading_metadata: NormalLoadingMetadata,
2752        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2753    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2754        todo!()
2755    }
2756    fn is_gptx(&self, _: &str) -> Result<bool> {
2757        Ok(true)
2758    }
2759    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2760        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2761        Ok(Box::new(cfg))
2762    }
2763}
2764
2765impl IsqModelLoader for DeepSeekV3Loader {
2766    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2767        let mut data = vec![
2768            Regex::new(r"lm_head\.(weight|bias)$")?,
2769            // Attention
2770            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2771            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2772            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2773        ];
2774        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2775        if cfg.q_lora_rank.is_some() {
2776            data.extend(vec![
2777                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2778                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2779            ]);
2780        } else {
2781            data.push(Regex::new(
2782                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2783            )?);
2784        }
2785        for layer_idx in 0..cfg.num_hidden_layers {
2786            if cfg.n_routed_experts.is_some()
2787                && layer_idx >= cfg.first_k_dense_replace
2788                && layer_idx % cfg.moe_layer_freq == 0
2789            {
2790                for i in 0..cfg.n_routed_experts.unwrap() {
2791                    data.extend(vec![
2792                        Regex::new(&format!(
2793                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2794                        ))?,
2795                        Regex::new(&format!(
2796                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2797                        ))?,
2798                        Regex::new(&format!(
2799                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2800                        ))?,
2801                    ]);
2802                }
2803                if cfg.n_shared_experts.is_some() {
2804                    data.extend(vec![
2805                        Regex::new(&format!(
2806                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2807                        ))?,
2808                        Regex::new(&format!(
2809                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2810                        ))?,
2811                        Regex::new(&format!(
2812                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2813                        ))?,
2814                    ]);
2815                }
2816            } else {
2817                data.extend(vec![
2818                    Regex::new(&format!(
2819                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2820                    ))?,
2821                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2822                    Regex::new(&format!(
2823                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2824                    ))?,
2825                ]);
2826            };
2827        }
2828        Ok(data)
2829    }
2830    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2831        self.isq_layer_regexes(config)
2832    }
2833
2834    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2835        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2836        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2837        for layer_idx in 0..cfg.num_hidden_layers {
2838            if cfg.n_routed_experts.is_some()
2839                && layer_idx >= cfg.first_k_dense_replace
2840                && layer_idx % cfg.moe_layer_freq == 0
2841            {
2842                for i in 0..cfg.n_routed_experts.unwrap() {
2843                    data.extend(vec![
2844                        Regex::new(&format!(
2845                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2846                        ))?,
2847                        Regex::new(&format!(
2848                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2849                        ))?,
2850                        Regex::new(&format!(
2851                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2852                        ))?,
2853                    ]);
2854                }
2855                if cfg.n_shared_experts.is_some() {
2856                    data.extend(vec![
2857                        Regex::new(&format!(
2858                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2859                        ))?,
2860                        Regex::new(&format!(
2861                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2862                        ))?,
2863                        Regex::new(&format!(
2864                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2865                        ))?,
2866                    ]);
2867                }
2868            } else {
2869                data.extend(vec![
2870                    Regex::new(&format!(
2871                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2872                    ))?,
2873                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2874                    Regex::new(&format!(
2875                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2876                    ))?,
2877                ]);
2878            };
2879        }
2880        Ok(data)
2881    }
2882    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2883        self.isq_layer_regexes_moqe(config)
2884    }
2885}
2886
2887impl DeviceMappedModelLoader for DeepSeekV3Loader {
2888    fn mapped_max_act_size_elems(
2889        &self,
2890        config: &str,
2891        params: &AutoDeviceMapParams,
2892    ) -> Result<usize> {
2893        let AutoDeviceMapParams::Text {
2894            max_seq_len,
2895            max_batch_size,
2896        } = params
2897        else {
2898            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2899        };
2900
2901        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2902
2903        Ok(
2904            max_batch_size
2905                * cfg.num_attention_heads
2906                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2907        )
2908    }
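    // Rough sketch of the bound above: the dominant mapped activation is the
    // attention score matrix, at most
    //     max_batch_size * num_attention_heads * min(max_seq_len, ATTENTION_CHUNK_SIZE)^2
    // elements. With hypothetical values max_batch_size = 1,
    // num_attention_heads = 128 and an effective chunk of 1024 tokens, that is
    // 1 * 128 * 1024 * 1024 = 134_217_728 elements (~256 MiB at f16).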
2909    fn non_mapped_max_act_size_elems(
2910        &self,
2911        _config: &str,
2912        _params: &AutoDeviceMapParams,
2913    ) -> Result<usize> {
2914        Ok(0)
2915    }
2916
2917    fn non_mapped_size_in_bytes(
2918        &self,
2919        config: &str,
2920        dtype: DType,
2921        weight_pack_factor: usize,
2922        _matformer_config: Option<&MatformerSliceConfig>,
2923    ) -> Result<usize> {
2924        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2925        let elems = {
2926            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2927            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2928            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2929                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2930            } else {
2931                0
2932            };
2933            let norm = cfg.hidden_size;
2934            embed_tokens + lm_head + norm
2935        };
2936        Ok(elems * dtype.size_in_bytes())
2937    }
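    // Worked example of the tied-embedding rule above (hypothetical numbers):
    // with hidden_size = 7168, vocab_size = 129_280, untied embeddings and
    // weight_pack_factor = 1, this counts
    //     embed_tokens = 7168 * 129_280
    //     lm_head      = 7168 * 129_280
    //     norm         = 7168
    // elements, i.e. roughly 3.7 GB of non-mapped weights at 2 bytes/element.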
2938
2939    fn layer_sizes_in_bytes(
2940        &self,
2941        config: &str,
2942        dtype: DType,
2943        weight_pack_factor: usize,
2944        _matformer_config: Option<&MatformerSliceConfig>,
2945    ) -> Result<Vec<usize>> {
2946        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2947        let mut per_layer_elems = Vec::new();
2948
2949        for layer_idx in 0..cfg.num_hidden_layers {
2950            let input_layernorm = cfg.hidden_size;
2951            let post_attention_layernorm = cfg.hidden_size;
2952
2953            let q_proj = match cfg.q_lora_rank {
2954                Some(lora_rank) => {
2955                    let a = cfg.hidden_size * lora_rank;
2956                    let norm = lora_rank;
2957                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2958                    a + norm + b
2959                }
2960                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2961            };
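            // When `q_lora_rank` is set, the query path is factored as
            // hidden_size -> lora_rank -> num_attention_heads * q_head_dim with a
            // layernorm over the rank dimension, hence the `a + norm + b` count
            // instead of a single dense hidden_size x (heads * head_dim) matrix.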
2962            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2963                / weight_pack_factor
2964                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2965            let kv_a_layernorm = cfg.kv_lora_rank;
2966            let kv_b_proj = cfg.kv_lora_rank
2967                * cfg.num_attention_heads
2968                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2969                / weight_pack_factor;
2970            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2971                / weight_pack_factor
2972                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2973
2974            let moe_block = {
2975                let mut sum = 0;
2976                if cfg.n_routed_experts.is_some()
2977                    && layer_idx >= cfg.first_k_dense_replace
2978                    && layer_idx % cfg.moe_layer_freq == 0
2979                {
2980                    let h_size = cfg.hidden_size;
2981                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2982                        * cfg.n_routed_experts.unwrap();
2983                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2984                        * cfg.n_routed_experts.unwrap();
2985                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2986                        * cfg.n_routed_experts.unwrap();
2987                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2988                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2989                            / weight_pack_factor;
2990                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2991                            / weight_pack_factor;
2992                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2993                            / weight_pack_factor;
2994                        gate_proj + up_proj + down_proj
2995                    } else {
2996                        0
2997                    };
2998                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2999                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3000                } else {
3001                    let h_size = cfg.hidden_size;
3002                    let i_size = cfg.intermediate_size;
3003                    let gate_proj = h_size * i_size / weight_pack_factor;
3004                    let up_proj = h_size * i_size / weight_pack_factor;
3005                    let down_proj = i_size * h_size / weight_pack_factor;
3006                    sum += gate_proj + up_proj + down_proj;
3007                }
3008                sum
3009            };
3010
3011            per_layer_elems.push(
3012                input_layernorm
3013                    + post_attention_layernorm
3014                    + q_proj
3015                    + kv_a_layernorm
3016                    + kv_a_proj_with_mqa
3017                    + kv_b_proj
3018                    + o_proj
3019                    + moe_block,
3020            );
3021        }
3022
3023        Ok(per_layer_elems
3024            .into_iter()
3025            .map(|x| x * dtype.size_in_bytes())
3026            .collect())
3027    }
3028
3029    fn num_layers(&self, config: &str) -> Result<usize> {
3030        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3031        Ok(cfg.num_hidden_layers)
3032    }
3033
3034    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3035        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3036
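        // MLA note: the materialized key head concatenates the rotary and
        // non-rotary parts (qk_rope_head_dim + qk_nope_head_dim), while values
        // use v_head_dim, so the k/v head dims below are intentionally asymmetric.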
3037        let cfg = ModelConfigMetadata {
3038            max_seq_len: cfg.max_position_embeddings,
3039            num_layers: cfg.num_hidden_layers,
3040            hidden_size: cfg.hidden_size,
3041            num_kv_heads: cfg.num_attention_heads,
3042            num_attn_heads: cfg.num_attention_heads,
3043            sliding_window: None,
3044            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3045            v_head_dim: cfg.v_head_dim,
3046        };
3047
3048        Ok(Box::new(cfg))
3049    }
3050}
3051
3052/// [`NormalLoader`] for a Qwen 3 model.
3053///
3054/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3055pub struct Qwen3Loader;
3056
3057impl NormalModelLoader for Qwen3Loader {
3058    fn load(
3059        &self,
3060        config: &str,
3061        vb: ShardedVarBuilder,
3062        normal_loading_metadata: NormalLoadingMetadata,
3063        attention_mechanism: AttentionImplementation,
3064    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3065        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3066
3067        Ok(Box::new(models::qwen3::Model::new(
3068            &cfg,
3069            vb,
3070            self.is_gptx(config)?,
3071            normal_loading_metadata,
3072            attention_mechanism,
3073        )?))
3074    }
3075    fn load_xlora(
3076        &self,
3077        _config: &str,
3078        _vb: ShardedVarBuilder,
3079        _lora_config: &[((String, String), LoraConfig)],
3080        _xlora_config: Option<XLoraConfig>,
3081        _xlora_ordering: Ordering,
3082        _normal_loading_metadata: NormalLoadingMetadata,
3083        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3084    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3085        todo!()
3086    }
3087    fn is_gptx(&self, _: &str) -> Result<bool> {
3088        Ok(true)
3089    }
3090    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3091        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3092
3093        Ok(Box::new(cfg))
3094    }
3095}
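// A minimal usage sketch (illustrative only; `config_json`, the ShardedVarBuilder
// and the loading metadata are assumed to be provided by the surrounding
// pipeline, as for every other loader in this file):
//
//     let loader = Qwen3Loader;
//     let n_layers = loader.num_layers(&config_json)?;
//     let isq_patterns = loader.isq_layer_regexes(&config_json)?;
//     let model = loader.load(&config_json, vb, metadata, attention_mechanism)?;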
3096
3097impl IsqModelLoader for Qwen3Loader {
3098    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3099        Ok(vec![
3100            Regex::new(r"lm_head\.(weight|bias)$")?,
3101            // Attention
3102            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3103            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3104            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3105            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3106            // MLP
3107            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3108            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3109            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3110        ])
3111    }
3112    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3113        self.isq_layer_regexes(config)
3114    }
3115}
3116
3117impl DeviceMappedModelLoader for Qwen3Loader {
3118    fn mapped_max_act_size_elems(
3119        &self,
3120        config: &str,
3121        params: &AutoDeviceMapParams,
3122    ) -> Result<usize> {
3123        let AutoDeviceMapParams::Text {
3124            max_seq_len,
3125            max_batch_size,
3126        } = params
3127        else {
3128            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3129        };
3130
3131        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3132
3133        Ok(
3134            max_batch_size
3135                * cfg.num_attention_heads
3136                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3137        )
3138    }
3139    fn non_mapped_max_act_size_elems(
3140        &self,
3141        _config: &str,
3142        _params: &AutoDeviceMapParams,
3143    ) -> Result<usize> {
3144        Ok(0)
3145    }
3146
3147    fn non_mapped_size_in_bytes(
3148        &self,
3149        config: &str,
3150        dtype: DType,
3151        weight_pack_factor: usize,
3152        _matformer_config: Option<&MatformerSliceConfig>,
3153    ) -> Result<usize> {
3154        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3155        let elems = {
3156            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3157            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3158            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3159                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3160            } else {
3161                0
3162            };
3163            let norm = cfg.hidden_size;
3164            embed_tokens + lm_head + norm
3165        };
3166        Ok(elems * dtype.size_in_bytes())
3167    }
3168
3169    fn layer_sizes_in_bytes(
3170        &self,
3171        config: &str,
3172        dtype: DType,
3173        weight_pack_factor: usize,
3174        _matformer_config: Option<&MatformerSliceConfig>,
3175    ) -> Result<Vec<usize>> {
3176        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3177        let per_layer_elems = {
3178            let input_layernorm = cfg.hidden_size;
3179            let post_attention_layernorm = cfg.hidden_size;
3180
3181            let size_in = cfg.hidden_size;
3182            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3183            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3184            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3185            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3186            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3187            let o_proj = size_q * size_in / weight_pack_factor;
3188
3189            let h_size = cfg.hidden_size;
3190            let i_size = cfg.intermediate_size;
3191            let gate_proj = h_size * i_size / weight_pack_factor;
3192            let up_proj = h_size * i_size / weight_pack_factor;
3193            let down_proj = i_size * h_size / weight_pack_factor;
3194
3195            let q_norm = cfg.head_dim();
3196            let k_norm = cfg.head_dim();
3197
3198            input_layernorm
3199                + post_attention_layernorm
3200                + q_proj
3201                + k_proj
3202                + v_proj
3203                + o_proj
3204                + gate_proj
3205                + up_proj
3206                + down_proj
3207                + q_norm
3208                + k_norm
3209        };
3210        Ok(vec![
3211            per_layer_elems * dtype.size_in_bytes();
3212            cfg.num_hidden_layers
3213        ])
3214    }
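    // Every Qwen3 decoder layer has the same shape, so a single per-layer size is
    // replicated `num_hidden_layers` times here; the MoE loaders in this file
    // (e.g. DeepSeekV3Loader, Qwen3MoELoader) instead size each layer
    // individually because sparse and dense layers differ.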
3215
3216    fn num_layers(&self, config: &str) -> Result<usize> {
3217        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3218        Ok(cfg.num_hidden_layers)
3219    }
3220
3221    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3222        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3223
3224        let cfg = ModelConfigMetadata {
3225            max_seq_len: cfg.max_position_embeddings,
3226            num_layers: cfg.num_hidden_layers,
3227            hidden_size: cfg.hidden_size,
3228            num_kv_heads: cfg.num_key_value_heads,
3229            num_attn_heads: cfg.num_attention_heads,
3230            sliding_window: cfg.sliding_window,
3231            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3232            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3233        };
3234
3235        Ok(Box::new(cfg))
3236    }
3237}
3238
3239/// [`NormalLoader`] for a GLM 4 model.
3240///
3241/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3242pub struct GLM4Loader;
3243
3244impl NormalModelLoader for GLM4Loader {
3245    fn load(
3246        &self,
3247        config: &str,
3248        vb: ShardedVarBuilder,
3249        normal_loading_metadata: NormalLoadingMetadata,
3250        attention_mechanism: AttentionImplementation,
3251    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3252        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3253
3254        Ok(Box::new(models::glm4::Model::new(
3255            &cfg,
3256            vb,
3257            self.is_gptx(config)?,
3258            normal_loading_metadata,
3259            attention_mechanism,
3260        )?))
3261    }
3262    fn load_xlora(
3263        &self,
3264        _config: &str,
3265        _vb: ShardedVarBuilder,
3266        _lora_config: &[((String, String), LoraConfig)],
3267        _xlora_config: Option<XLoraConfig>,
3268        _xlora_ordering: Ordering,
3269        _normal_loading_metadata: NormalLoadingMetadata,
3270        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3271    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3272        todo!()
3273    }
3274    fn is_gptx(&self, _: &str) -> Result<bool> {
3275        Ok(true)
3276    }
3277    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3278        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3279
3280        Ok(Box::new(cfg))
3281    }
3282}
3283
3284impl IsqModelLoader for GLM4Loader {
3285    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3286        Ok(vec![
3287            Regex::new(r"lm_head\.(weight|bias)$")?,
3288            // Attention
3289            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3290            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3291            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3292            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3293            // MLP
3294            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3295            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3296            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3297        ])
3298    }
3299    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3300        self.isq_layer_regexes(config)
3301    }
3302}
3303
3304impl DeviceMappedModelLoader for GLM4Loader {
3305    fn mapped_max_act_size_elems(
3306        &self,
3307        config: &str,
3308        params: &AutoDeviceMapParams,
3309    ) -> Result<usize> {
3310        let AutoDeviceMapParams::Text {
3311            max_seq_len,
3312            max_batch_size,
3313        } = params
3314        else {
3315            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3316        };
3317
3318        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3319
3320        Ok(
3321            max_batch_size
3322                * cfg.num_attention_heads
3323                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3324        )
3325    }
3326    fn non_mapped_max_act_size_elems(
3327        &self,
3328        _config: &str,
3329        _params: &AutoDeviceMapParams,
3330    ) -> Result<usize> {
3331        Ok(0)
3332    }
3333
3334    fn non_mapped_size_in_bytes(
3335        &self,
3336        config: &str,
3337        dtype: DType,
3338        weight_pack_factor: usize,
3339        _matformer_config: Option<&MatformerSliceConfig>,
3340    ) -> Result<usize> {
3341        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3342        let elems = {
3343            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3344            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3345            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3346                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3347            } else {
3348                0
3349            };
3350            let norm = cfg.hidden_size;
3351            embed_tokens + lm_head + norm
3352        };
3353        Ok(elems * dtype.size_in_bytes())
3354    }
3355
3356    fn layer_sizes_in_bytes(
3357        &self,
3358        config: &str,
3359        dtype: DType,
3360        weight_pack_factor: usize,
3361        _matformer_config: Option<&MatformerSliceConfig>,
3362    ) -> Result<Vec<usize>> {
3363        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3364        let per_layer_elems = {
3365            let input_layernorm = cfg.hidden_size;
3366            let post_attention_layernorm = cfg.hidden_size * 3; // also includes post_self_attn_layernorm and post_mlp_layernorm
3367
3368            let size_in = cfg.hidden_size;
3369            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3370            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3371            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3372            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3373            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3374            let o_proj = size_q * size_in / weight_pack_factor;
3375
3376            let h_size = cfg.hidden_size;
3377            let i_size = cfg.intermediate_size;
3378            let gate_proj = h_size * i_size / weight_pack_factor;
3379            let up_proj = h_size * i_size / weight_pack_factor;
3380            let down_proj = i_size * h_size / weight_pack_factor;
3381
3382            input_layernorm
3383                + post_attention_layernorm
3384                + q_proj
3385                + k_proj
3386                + v_proj
3387                + o_proj
3388                + gate_proj
3389                + up_proj
3390                + down_proj
3391        };
3392        Ok(vec![
3393            per_layer_elems * dtype.size_in_bytes();
3394            cfg.num_hidden_layers
3395        ])
3396    }
3397
3398    fn num_layers(&self, config: &str) -> Result<usize> {
3399        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3400        Ok(cfg.num_hidden_layers)
3401    }
3402
3403    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3404        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3405
3406        let cfg = ModelConfigMetadata {
3407            max_seq_len: cfg.max_position_embeddings,
3408            num_layers: cfg.num_hidden_layers,
3409            hidden_size: cfg.hidden_size,
3410            num_kv_heads: cfg.num_key_value_heads,
3411            num_attn_heads: cfg.num_attention_heads,
3412            sliding_window: cfg.sliding_window,
3413            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3414            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3415        };
3416
3417        Ok(Box::new(cfg))
3418    }
3419}
3420
3421/// [`NormalLoader`] for a Qwen 3 MoE model.
3422///
3423/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3424pub struct Qwen3MoELoader;
3425
3426impl NormalModelLoader for Qwen3MoELoader {
3427    fn load(
3428        &self,
3429        config: &str,
3430        vb: ShardedVarBuilder,
3431        normal_loading_metadata: NormalLoadingMetadata,
3432        attention_mechanism: AttentionImplementation,
3433    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3434        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3435
3436        Ok(Box::new(models::qwen3_moe::Model::new(
3437            &cfg,
3438            vb,
3439            self.is_gptx(config)?,
3440            normal_loading_metadata,
3441            attention_mechanism,
3442        )?))
3443    }
3444    fn load_xlora(
3445        &self,
3446        _config: &str,
3447        _vb: ShardedVarBuilder,
3448        _lora_config: &[((String, String), LoraConfig)],
3449        _xlora_config: Option<XLoraConfig>,
3450        _xlora_ordering: Ordering,
3451        _normal_loading_metadata: NormalLoadingMetadata,
3452        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3453    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3454        todo!()
3455    }
3456    fn is_gptx(&self, _: &str) -> Result<bool> {
3457        Ok(true)
3458    }
3459    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3460        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3461
3462        Ok(Box::new(cfg))
3463    }
3464}
3465
3466impl IsqModelLoader for Qwen3MoELoader {
3467    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3468        Ok(vec![
3469            Regex::new(r"lm_head\.(weight|bias)$")?,
3470            // Attention
3471            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3472            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3473            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3474            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3475            // MLP
3476            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3477            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3478            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3479            // MLP MoE
3480            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3481            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3482            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3483        ])
3484    }
3485    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3486        self.isq_layer_regexes(config)
3487    }
3488    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3489        self.isq_layer_regexes_moqe(config)
3490    }
3491}
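// Unlike the DeepSeekV3 loader, which expands explicit per-layer/per-expert
// patterns from the parsed config, the expert patterns here use generic
// `(\d+)` captures, so a single regex covers every layer and expert index.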
3492
3493impl DeviceMappedModelLoader for Qwen3MoELoader {
3494    fn mapped_max_act_size_elems(
3495        &self,
3496        config: &str,
3497        params: &AutoDeviceMapParams,
3498    ) -> Result<usize> {
3499        let AutoDeviceMapParams::Text {
3500            max_seq_len,
3501            max_batch_size,
3502        } = params
3503        else {
3504            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3505        };
3506
3507        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3508
3509        Ok(
3510            max_batch_size
3511                * cfg.num_attention_heads
3512                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3513        )
3514    }
3515    fn non_mapped_max_act_size_elems(
3516        &self,
3517        _config: &str,
3518        _params: &AutoDeviceMapParams,
3519    ) -> Result<usize> {
3520        Ok(0)
3521    }
3522
3523    fn non_mapped_size_in_bytes(
3524        &self,
3525        config: &str,
3526        dtype: DType,
3527        weight_pack_factor: usize,
3528        _matformer_config: Option<&MatformerSliceConfig>,
3529    ) -> Result<usize> {
3530        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3531        let elems = {
3532            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3533            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3534            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3535                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3536            } else {
3537                0
3538            };
3539            let norm = cfg.hidden_size;
3540            embed_tokens + lm_head + norm
3541        };
3542        Ok(elems * dtype.size_in_bytes())
3543    }
3544
3545    fn layer_sizes_in_bytes(
3546        &self,
3547        config: &str,
3548        dtype: DType,
3549        weight_pack_factor: usize,
3550        _matformer_config: Option<&MatformerSliceConfig>,
3551    ) -> Result<Vec<usize>> {
3552        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3553
3554        let mut layer_sizes_in_bytes = Vec::new();
3555        for layer_idx in 0..cfg.num_hidden_layers {
3556            let input_layernorm = cfg.hidden_size;
3557            let post_attention_layernorm = cfg.hidden_size;
3558
3559            let size_in = cfg.hidden_size;
3560            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3561            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3562            let q_proj = size_in * size_q / weight_pack_factor;
3563            let k_proj = size_in * size_kv / weight_pack_factor;
3564            let v_proj = size_in * size_kv / weight_pack_factor;
3565            let o_proj = size_q * size_in / weight_pack_factor;
3566
3567            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3568                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3569            {
3570                let gate_size = cfg.hidden_size * cfg.num_experts;
3571                let expert_size = {
3572                    let h_size = cfg.hidden_size;
3573                    let i_size = cfg.moe_intermediate_size;
3574                    let gate_proj = h_size * i_size / weight_pack_factor;
3575                    let up_proj = h_size * i_size / weight_pack_factor;
3576                    let down_proj = i_size * h_size / weight_pack_factor;
3577                    gate_proj + up_proj + down_proj
3578                };
3579                expert_size * cfg.num_experts + gate_size
3580            } else {
3581                let h_size = cfg.hidden_size;
3582                let i_size = cfg.intermediate_size;
3583                let gate_proj = h_size * i_size / weight_pack_factor;
3584                let up_proj = h_size * i_size / weight_pack_factor;
3585                let down_proj = i_size * h_size / weight_pack_factor;
3586                gate_proj + up_proj + down_proj
3587            };
3588
3589            let q_norm = cfg.head_dim();
3590            let k_norm = cfg.head_dim();
3591
3592            let size_elems = input_layernorm
3593                + post_attention_layernorm
3594                + q_proj
3595                + k_proj
3596                + v_proj
3597                + o_proj
3598                + mlp_size
3599                + q_norm
3600                + k_norm;
3601
3602            let size_in_bytes = size_elems * dtype.size_in_bytes();
3603            layer_sizes_in_bytes.push(size_in_bytes);
3604        }
3605
3606        Ok(layer_sizes_in_bytes)
3607    }
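    // Illustration of the sparse/dense split above (hypothetical values): with
    // num_experts > 0, an empty `mlp_only_layers` and decoder_sparse_step = 2,
    // only layers where (layer_idx + 1) % 2 == 0 (i.e. odd indices) are sized as
    // MoE layers; the rest use the dense gate/up/down projection sizes.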
3608
3609    fn num_layers(&self, config: &str) -> Result<usize> {
3610        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3611        Ok(cfg.num_hidden_layers)
3612    }
3613
3614    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3615        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3616
3617        let cfg = ModelConfigMetadata {
3618            max_seq_len: cfg.max_position_embeddings,
3619            num_layers: cfg.num_hidden_layers,
3620            hidden_size: cfg.hidden_size,
3621            num_kv_heads: cfg.num_key_value_heads,
3622            num_attn_heads: cfg.num_attention_heads,
3623            sliding_window: cfg.sliding_window,
3624            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3625            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3626        };
3627
3628        Ok(Box::new(cfg))
3629    }
3630}
3631
3632// ======================== SmolLm3 loader
3633
3634/// [`NormalLoader`] for a SmolLm3 model.
3635///
3636/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3637pub struct SmolLm3Loader;
3638
3639impl NormalModelLoader for SmolLm3Loader {
3640    fn load(
3641        &self,
3642        config: &str,
3643        vb: ShardedVarBuilder,
3644        normal_loading_metadata: NormalLoadingMetadata,
3645        attention_mechanism: AttentionImplementation,
3646    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3647        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3648
3649        Ok(Box::new(models::smollm3::SmolLm3::new(
3650            &cfg,
3651            vb,
3652            self.is_gptx(config)?,
3653            normal_loading_metadata,
3654            attention_mechanism,
3655        )?))
3656    }
3657    fn load_xlora(
3658        &self,
3659        _config: &str,
3660        _vb: ShardedVarBuilder,
3661        _lora_config: &[((String, String), LoraConfig)],
3662        _xlora_config: Option<XLoraConfig>,
3663        _xlora_ordering: Ordering,
3664        _normal_loading_metadata: NormalLoadingMetadata,
3665        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3666    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3667        todo!()
3668    }
3669    fn is_gptx(&self, _: &str) -> Result<bool> {
3670        Ok(true)
3671    }
3672    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3673        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3674        Ok(Box::new(cfg))
3675    }
3676}
3677
3678impl IsqModelLoader for SmolLm3Loader {
3679    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3680        Ok(vec![
3681            Regex::new(r"lm_head\.(weight|bias)$")?,
3682            // Attention
3683            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3684            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3685            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3686            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3687            // MLP
3688            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3689            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3690            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3691        ])
3692    }
3693    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3694        self.isq_layer_regexes(config)
3695    }
3696}
3697
3698impl DeviceMappedModelLoader for SmolLm3Loader {
3699    fn mapped_max_act_size_elems(
3700        &self,
3701        config: &str,
3702        params: &AutoDeviceMapParams,
3703    ) -> Result<usize> {
3704        let AutoDeviceMapParams::Text {
3705            max_seq_len,
3706            max_batch_size,
3707        } = params
3708        else {
3709            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3710        };
3711
3712        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3713
3714        Ok(
3715            max_batch_size
3716                * cfg.num_attention_heads
3717                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3718        )
3719    }
3720    fn non_mapped_max_act_size_elems(
3721        &self,
3722        _config: &str,
3723        _params: &AutoDeviceMapParams,
3724    ) -> Result<usize> {
3725        Ok(0)
3726    }
3727
3728    fn non_mapped_size_in_bytes(
3729        &self,
3730        config: &str,
3731        dtype: DType,
3732        weight_pack_factor: usize,
3733        _matformer_config: Option<&MatformerSliceConfig>,
3734    ) -> Result<usize> {
3735        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3736
3737        let elems = {
3738            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3739            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3740            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3741                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3742            } else {
3743                0
3744            };
3745            let norm = cfg.hidden_size;
3746            embed_tokens + lm_head + norm
3747        };
3748        Ok(elems * dtype.size_in_bytes())
3749    }
3750
3751    fn layer_sizes_in_bytes(
3752        &self,
3753        config: &str,
3754        dtype: DType,
3755        weight_pack_factor: usize,
3756        _matformer_config: Option<&MatformerSliceConfig>,
3757    ) -> Result<Vec<usize>> {
3758        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3759
3760        let per_layer_elems = {
3761            let input_layernorm = cfg.hidden_size;
3762            let post_attention_layernorm = cfg.hidden_size;
3763
3764            let size_in = cfg.hidden_size;
3765            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3766            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3767            let q_proj = size_in * size_q / weight_pack_factor;
3768            let k_proj = size_in * size_kv / weight_pack_factor;
3769            let v_proj = size_in * size_kv / weight_pack_factor;
3770            let o_proj = size_q * size_in / weight_pack_factor;
3771
3772            let h_size = cfg.hidden_size;
3773            let i_size = cfg.intermediate_size;
3774            let gate_proj = h_size * i_size / weight_pack_factor;
3775            let up_proj = h_size * i_size / weight_pack_factor;
3776            let down_proj = i_size * h_size / weight_pack_factor;
3777
3778            input_layernorm
3779                + post_attention_layernorm
3780                + q_proj
3781                + k_proj
3782                + v_proj
3783                + o_proj
3784                + gate_proj
3785                + up_proj
3786                + down_proj
3787        };
3788        Ok(vec![
3789            per_layer_elems * dtype.size_in_bytes();
3790            cfg.num_hidden_layers
3791        ])
3792    }
3793
3794    fn num_layers(&self, config: &str) -> Result<usize> {
3795        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3796
3797        Ok(cfg.num_hidden_layers)
3798    }
3799    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3800        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3801
3802        let cfg = ModelConfigMetadata {
3803            max_seq_len: cfg.max_position_embeddings,
3804            num_layers: cfg.num_hidden_layers,
3805            hidden_size: cfg.hidden_size,
3806            num_kv_heads: cfg.num_key_value_heads,
3807            num_attn_heads: cfg.num_attention_heads,
3808            sliding_window: None,
3809            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3810            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3811        };
3812
3813        Ok(Box::new(cfg))
3814    }
3815}
3816
3817// ======================== GraniteMoeHybrid loader
3818
3819/// [`NormalLoader`] for a GraniteMoeHybrid model (IBM Granite 4.0).
3820///
3821/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3822pub struct GraniteMoeHybridLoader;
3823
3824impl NormalModelLoader for GraniteMoeHybridLoader {
3825    fn load(
3826        &self,
3827        config: &str,
3828        vb: ShardedVarBuilder,
3829        normal_loading_metadata: NormalLoadingMetadata,
3830        attention_mechanism: AttentionImplementation,
3831    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3832        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3833
3834        Ok(Box::new(models::granite::GraniteMoeHybrid::new(
3835            &cfg,
3836            vb,
3837            self.is_gptx(config)?,
3838            normal_loading_metadata,
3839            attention_mechanism,
3840        )?))
3841    }
3842    fn load_xlora(
3843        &self,
3844        _config: &str,
3845        _vb: ShardedVarBuilder,
3846        _lora_config: &[((String, String), LoraConfig)],
3847        _xlora_config: Option<XLoraConfig>,
3848        _xlora_ordering: Ordering,
3849        _normal_loading_metadata: NormalLoadingMetadata,
3850        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3851    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3852        todo!()
3853    }
3854    fn is_gptx(&self, _: &str) -> Result<bool> {
3855        Ok(true)
3856    }
3857    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3858        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3859        Ok(Box::new(cfg))
3860    }
3861}
3862
3863impl IsqModelLoader for GraniteMoeHybridLoader {
3864    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3865        Ok(vec![
3866            Regex::new(r"lm_head\.(weight|bias)$")?,
3867            // Attention
3868            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3869            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3870            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3871            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3872            // MLP (GraniteMLP uses shared_mlp.input_linear and shared_mlp.output_linear)
3873            Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
3874            Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
3875        ])
3876    }
3877    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3878        self.isq_layer_regexes(config)
3879    }
3880}
3881
3882impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
3883    fn mapped_max_act_size_elems(
3884        &self,
3885        config: &str,
3886        params: &AutoDeviceMapParams,
3887    ) -> Result<usize> {
3888        let AutoDeviceMapParams::Text {
3889            max_seq_len,
3890            max_batch_size,
3891        } = params
3892        else {
3893            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3894        };
3895
3896        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3897
3898        Ok(
3899            max_batch_size
3900                * cfg.num_attention_heads
3901                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3902        )
3903    }
3904    fn non_mapped_max_act_size_elems(
3905        &self,
3906        _config: &str,
3907        _params: &AutoDeviceMapParams,
3908    ) -> Result<usize> {
3909        Ok(0)
3910    }
3911
3912    fn non_mapped_size_in_bytes(
3913        &self,
3914        config: &str,
3915        dtype: DType,
3916        weight_pack_factor: usize,
3917        _matformer_config: Option<&MatformerSliceConfig>,
3918    ) -> Result<usize> {
3919        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3920
3921        let elems = {
3922            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3923            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3924            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3925                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3926            } else {
3927                0
3928            };
3929            let norm = cfg.hidden_size;
3930            embed_tokens + lm_head + norm
3931        };
3932        Ok(elems * dtype.size_in_bytes())
3933    }
3934
3935    fn layer_sizes_in_bytes(
3936        &self,
3937        config: &str,
3938        dtype: DType,
3939        weight_pack_factor: usize,
3940        _matformer_config: Option<&MatformerSliceConfig>,
3941    ) -> Result<Vec<usize>> {
3942        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3943
3944        let per_layer_elems = {
3945            let input_layernorm = cfg.hidden_size;
3946            let post_attention_layernorm = cfg.hidden_size;
3947
3948            let size_in = cfg.hidden_size;
3949            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3950            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
3951            let q_proj = size_in * size_q / weight_pack_factor;
3952            let k_proj = size_in * size_kv / weight_pack_factor;
3953            let v_proj = size_in * size_kv / weight_pack_factor;
3954            let o_proj = size_q * size_in / weight_pack_factor;
3955
3956            let h_size = cfg.hidden_size;
3957            let shared_i_size = cfg.shared_intermediate_size();
3958            // GraniteMLP: input_linear (h_size -> shared_i_size * 2), output_linear (shared_i_size -> h_size)
3959            let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
3960            let output_linear = shared_i_size * h_size / weight_pack_factor;
3961
3962            input_layernorm
3963                + post_attention_layernorm
3964                + q_proj
3965                + k_proj
3966                + v_proj
3967                + o_proj
3968                + input_linear
3969                + output_linear
3970        };
3971        Ok(vec![
3972            per_layer_elems * dtype.size_in_bytes();
3973            cfg.num_hidden_layers
3974        ])
3975    }
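    // Worked example for the fused shared MLP above (hypothetical numbers): with
    // hidden_size = 4096, shared_intermediate_size() = 12_288 and
    // weight_pack_factor = 1, input_linear holds 4096 * 12_288 * 2 = 100_663_296
    // elements and output_linear holds 12_288 * 4096 = 50_331_648, so the shared
    // MLP dominates the per-layer footprint.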
3976
3977    fn num_layers(&self, config: &str) -> Result<usize> {
3978        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3979
3980        Ok(cfg.num_hidden_layers)
3981    }
3982    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3983        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3984
3985        let cfg = ModelConfigMetadata {
3986            max_seq_len: cfg.max_position_embeddings,
3987            num_layers: cfg.num_hidden_layers,
3988            hidden_size: cfg.hidden_size,
3989            num_kv_heads: cfg.num_key_value_heads(),
3990            num_attn_heads: cfg.num_attention_heads,
3991            sliding_window: None,
3992            k_head_dim: cfg.head_dim(),
3993            v_head_dim: cfg.head_dim(),
3994        };
3995
3996        Ok(Box::new(cfg))
3997    }
3998}