mistralrs_core/pipeline/loaders/normal_loaders.rs

use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    lora::{LoraConfig, Ordering},
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{
        isq::IsqModelLoader,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::varbuilder_utils::DeviceForLoadTensor,
    xlora_models::NonGranularState,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use mistralrs_quant::log::once_log_info;

use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use crate::{
    models,
    xlora_models::{self, XLoraConfig},
};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

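/// A plain-text decoder model that a pipeline can drive: it exposes the regular and X-LoRA
/// forward passes plus accessors for the cache, device, maximum sequence length, and config
/// metadata.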
pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn is_xlora(&self) -> bool;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
}

/// Metadata for loading a model with ISQ or device mapping.
pub struct NormalLoadingMetadata {
    // Device mapping metadata which can be used to construct a concrete device mapper
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Flag to check if loading in ISQ
    pub loading_isq: bool,
    // Device mapping target device (the one that is not the cpu)
    pub real_device: Device,
    // MultiProgress support for parallelized loading
    pub multi_progress: Arc<MultiProgress>,
}

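/// Builder for a concrete [`NormalModel`]: given the raw model config (a JSON string) it can
/// construct the model (optionally with X-LoRA adapters), report whether paged attention is
/// supported, and decide which device each tensor should be loaded onto.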
pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    #[allow(clippy::too_many_arguments)]
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
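            // Route each tensor named `...layers.<N>...` to the device assigned to layer N;
            // anything that does not match (embeddings, norms, lm_head) stays on the base device.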
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the normal model as.
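///
/// The serde names match the strings accepted by the [`FromStr`] impl, so a value can be parsed
/// from its lowercase architecture name. A minimal sketch:
///
/// ```ignore
/// let arch: NormalLoaderType = "mistral".parse().unwrap();
/// assert_eq!(arch, NormalLoaderType::Mistral);
/// ```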
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    #[serde(rename = "qwen3")]
    Qwen3,
    #[serde(rename = "glm4")]
    GLM4,
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
    #[serde(rename = "smollm3")]
    SmolLm3,
}

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
impl NormalLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "MistralForCausalLM" => Ok(Self::Mistral),
            "MixtralForCausalLM" => Ok(Self::Mixtral),
            "GemmaForCausalLM" => Ok(Self::Gemma),
            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
            "PhiForCausalLM" => Ok(Self::Phi2),
            "Phi3ForCausalLM" => Ok(Self::Phi3),
            "LlamaForCausalLM" => Ok(Self::Llama),
            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
            "Glm4ForCausalLM" => Ok(Self::GLM4),
            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `*ForCausalLM` model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for NormalLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "mistral" => Ok(Self::Mistral),
            "gemma" => Ok(Self::Gemma),
            "mixtral" => Ok(Self::Mixtral),
            "llama" => Ok(Self::Llama),
            "phi2" => Ok(Self::Phi2),
            "phi3" => Ok(Self::Phi3),
            "qwen2" => Ok(Self::Qwen2),
            "gemma2" => Ok(Self::Gemma2),
            "starcoder2" => Ok(Self::Starcoder2),
            "phi3.5moe" => Ok(Self::Phi3_5MoE),
            "deepseekv2" => Ok(Self::DeepSeekV2),
            "deepseekv3" => Ok(Self::DeepSeekV3),
            "qwen3" => Ok(Self::Qwen3),
            "glm4" => Ok(Self::GLM4),
            "qwen3moe" => Ok(Self::Qwen3Moe),
            "smollm3" => Ok(Self::SmolLm3),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`.")),
        }
    }
}

impl Display for NormalLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gemma => write!(f, "gemma"),
            Self::Gemma2 => write!(f, "gemma2"),
            Self::Llama => write!(f, "llama"),
            Self::Mistral => write!(f, "mistral"),
            Self::Mixtral => write!(f, "mixtral"),
            Self::Phi2 => write!(f, "phi2"),
            Self::Phi3 => write!(f, "phi3"),
            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
            Self::Qwen2 => write!(f, "qwen2"),
            Self::Starcoder2 => write!(f, "starcoder2"),
            Self::DeepSeekV2 => write!(f, "deepseekv2"),
            Self::DeepSeekV3 => write!(f, "deepseekv3"),
            Self::Qwen3 => write!(f, "qwen3"),
            Self::GLM4 => write!(f, "glm4"),
            Self::Qwen3Moe => write!(f, "qwen3moe"),
            Self::SmolLm3 => write!(f, "smollm3"),
        }
    }
}

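/// Adds `$size` elements when `$cond` is true and nothing otherwise; used below to account for
/// optional bias tensors in the per-layer size estimates.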
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

/// Load a model based on its Hugging Face Transformers `*ForCausalLM` model class.
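///
/// It inspects the `architectures` field of the model's `config.json` and dispatches to the
/// matching concrete loader. A minimal sketch, assuming `config_json` holds that file's contents:
///
/// ```ignore
/// let loader = AutoNormalLoader;
/// let repr = loader.get_config_repr(config_json)?;
/// println!("{repr:?}");
/// ```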
pub struct AutoNormalLoader;

#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    architectures: Vec<String>,
}

impl AutoNormalLoader {
    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected to have one name for `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = NormalLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
        }
    }
}

impl NormalModelLoader for AutoNormalLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load_xlora(
            config,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            normal_loading_metadata,
            preload_adapters,
        )
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }
    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.supports_paged_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoNormalLoader {
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
    }
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for AutoNormalLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

// ======================== Mistral loader

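/// [`NormalLoader`] for a Mistral model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html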
pub struct MistralLoader;

impl NormalModelLoader for MistralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(models::mistral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(xlora_models::XLoraMistral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MistralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MistralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

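        // The dominant mapped activation is the attention score matrix:
        // batch x heads x chunk x chunk elements for one prompt chunk.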
        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

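        // Rough per-layer parameter count: the two layer norms, the attention projections, and
        // the gate/up/down MLP projections; packed weights are divided by `weight_pack_factor`.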
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Gemma loader

/// [`NormalLoader`] for a Gemma model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct GemmaLoader;

impl NormalModelLoader for GemmaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::gemma::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraGemma::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for GemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Llama loader

/// [`NormalLoader`] for a Llama model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct LlamaLoader;

impl NormalModelLoader for LlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::llama::Llama::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraLlama::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for LlamaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for LlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Mixtral loader

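/// [`NormalLoader`] for a Mixtral model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html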
973pub struct MixtralLoader;
974
975impl NormalModelLoader for MixtralLoader {
976    fn load(
977        &self,
978        config: &str,
979        vb: ShardedVarBuilder,
980        normal_loading_metadata: NormalLoadingMetadata,
981        attention_mechanism: AttentionImplementation,
982    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
983        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
984
985        Ok(Box::new(models::mixtral::Model::new(
986            &cfg,
987            vb,
988            self.is_gptx(config)?,
989            normal_loading_metadata,
990            attention_mechanism,
991        )?))
992    }
993    fn load_xlora(
994        &self,
995        config: &str,
996        vb: ShardedVarBuilder,
997        lora_config: &[((String, String), LoraConfig)],
998        xlora_config: Option<XLoraConfig>,
999        xlora_ordering: Ordering,
1000        normal_loading_metadata: NormalLoadingMetadata,
1001        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1002    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1003        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1004
1005        Ok(Box::new(xlora_models::XLoraMixtral::new(
1006            &cfg,
1007            vb,
1008            lora_config,
1009            xlora_config,
1010            xlora_ordering,
1011            self.is_gptx(config)?,
1012            normal_loading_metadata,
1013            preload_adapters,
1014        )?))
1015    }
1016    fn is_gptx(&self, _: &str) -> Result<bool> {
1017        Ok(true)
1018    }
1019    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1020        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1021
1022        Ok(Box::new(cfg))
1023    }
1024}
1025
1026impl IsqModelLoader for MixtralLoader {
1027    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1028        Ok(vec![
1029            Regex::new(r"lm_head\.(weight|bias)$")?,
1030            // Attention
1031            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1032            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1033            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1034            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1035            // Experts
1036            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1037            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1038            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1039            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1040        ])
1041    }
1042    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1043        self.isq_layer_regexes(config)
1044    }
1045}
1046
1047impl DeviceMappedModelLoader for MixtralLoader {
1048    fn mapped_max_act_size_elems(
1049        &self,
1050        config: &str,
1051        params: &AutoDeviceMapParams,
1052        prompt_chunksize: usize,
1053    ) -> Result<usize> {
1054        let AutoDeviceMapParams::Text {
1055            max_seq_len: _,
1056            max_batch_size,
1057        } = params
1058        else {
1059            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1060        };
1061
1062        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1063
1064        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1065    }
1066    fn non_mapped_max_act_size_elems(
1067        &self,
1068        _config: &str,
1069        _params: &AutoDeviceMapParams,
1070    ) -> Result<usize> {
1071        Ok(0)
1072    }
1073
1074    fn non_mapped_size_in_bytes(
1075        &self,
1076        config: &str,
1077        dtype: DType,
1078        weight_pack_factor: usize,
1079    ) -> Result<usize> {
1080        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1081
1082        let elems = {
1083            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1084            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1085            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1086                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1087            } else {
1088                0
1089            };
1090            let norm = cfg.hidden_size;
1091            embed_tokens + lm_head + norm
1092        };
1093        Ok(elems * dtype.size_in_bytes())
1094    }
1095
1096    fn layer_sizes_in_bytes(
1097        &self,
1098        config: &str,
1099        dtype: DType,
1100        weight_pack_factor: usize,
1101    ) -> Result<Vec<usize>> {
1102        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1103
1104        let per_layer_elems = {
1105            let input_layernorm = cfg.hidden_size;
1106            let post_attention_layernorm = cfg.hidden_size;
1107
1108            let size_in = cfg.hidden_size;
1109            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1110            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1111            let q_proj = size_in * size_q / weight_pack_factor;
1112            let k_proj = size_in * size_kv / weight_pack_factor;
1113            let v_proj = size_in * size_kv / weight_pack_factor;
1114            let o_proj = size_q * size_in / weight_pack_factor;
1115
1116            let moe_block = {
1117                let gate = cfg.hidden_size * cfg.num_local_experts;
1118                // Assume quantizing weight pack factor
1119                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1120                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1121                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1122                gate + cfg.num_local_experts * w1
1123                    + cfg.num_local_experts * w2
1124                    + cfg.num_local_experts * w3
1125            };
1126
1127            input_layernorm
1128                + post_attention_layernorm
1129                + q_proj
1130                + k_proj
1131                + v_proj
1132                + o_proj
1133                + moe_block
1134        };
1135        Ok(vec![
1136            per_layer_elems * dtype.size_in_bytes();
1137            cfg.num_hidden_layers
1138        ])
1139    }
1140
1141    fn num_layers(&self, config: &str) -> Result<usize> {
1142        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1143
1144        Ok(cfg.num_hidden_layers)
1145    }
1146
1147    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1148        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1149
1150        let cfg = ModelConfigMetadata {
1151            max_seq_len: cfg.max_position_embeddings,
1152            num_layers: cfg.num_hidden_layers,
1153            hidden_size: cfg.hidden_size,
1154            num_kv_heads: cfg.num_key_value_heads,
1155            num_attn_heads: cfg.num_attention_heads,
1156            sliding_window: cfg.sliding_window,
1157            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1158            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1159        };
1160
1161        Ok(Box::new(cfg))
1162    }
1163}
1164
1165// ======================== Phi2 loader
1166
1167/// [`NormalLoader`] for a Phi 2 model.
1168///
1169/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1170pub struct Phi2Loader;
1171
1172impl NormalModelLoader for Phi2Loader {
1173    fn load(
1174        &self,
1175        config: &str,
1176        vb: ShardedVarBuilder,
1177        normal_loading_metadata: NormalLoadingMetadata,
1178        attention_mechanism: AttentionImplementation,
1179    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1180        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1181
1182        Ok(Box::new(models::phi2::Model::new(
1183            &cfg,
1184            vb,
1185            self.is_gptx(config)?,
1186            normal_loading_metadata,
1187            attention_mechanism,
1188        )?))
1189    }
1190    fn load_xlora(
1191        &self,
1192        config: &str,
1193        vb: ShardedVarBuilder,
1194        lora_config: &[((String, String), LoraConfig)],
1195        xlora_config: Option<XLoraConfig>,
1196        xlora_ordering: Ordering,
1197        normal_loading_metadata: NormalLoadingMetadata,
1198        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1199    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1200        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1201
1202        Ok(Box::new(xlora_models::XLoraPhi2::new(
1203            &cfg,
1204            vb,
1205            lora_config,
1206            xlora_config,
1207            xlora_ordering,
1208            self.is_gptx(config)?,
1209            normal_loading_metadata,
1210            preload_adapters,
1211        )?))
1212    }
1213    fn is_gptx(&self, _: &str) -> Result<bool> {
1214        Ok(true)
1215    }
1216    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1217        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1218
1219        Ok(Box::new(cfg))
1220    }
1221}
1222
1223impl IsqModelLoader for Phi2Loader {
1224    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1225        Ok(vec![
1226            Regex::new(r"lm_head\.(weight|bias)$")?,
1227            // Attention
1228            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1229            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1230            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1231            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1232            // MLP
1233            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1234            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1235        ])
1236    }
1237    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1238        self.isq_layer_regexes(config)
1239    }
1240}
1241
1242impl DeviceMappedModelLoader for Phi2Loader {
1243    fn mapped_max_act_size_elems(
1244        &self,
1245        config: &str,
1246        params: &AutoDeviceMapParams,
1247        prompt_chunksize: usize,
1248    ) -> Result<usize> {
1249        let AutoDeviceMapParams::Text {
1250            max_seq_len: _,
1251            max_batch_size,
1252        } = params
1253        else {
1254            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1255        };
1256
1257        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1258
1259        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1260    }
1261    fn non_mapped_max_act_size_elems(
1262        &self,
1263        _config: &str,
1264        _params: &AutoDeviceMapParams,
1265    ) -> Result<usize> {
1266        Ok(0)
1267    }
1268
1269    fn non_mapped_size_in_bytes(
1270        &self,
1271        config: &str,
1272        dtype: DType,
1273        weight_pack_factor: usize,
1274    ) -> Result<usize> {
1275        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1276
1277        let elems = {
1278            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1279            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1280            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1281                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1282            } else {
1283                0
1284            };
1285            let norm = cfg.hidden_size;
1286            embed_tokens + lm_head + norm
1287        };
1288        Ok(elems * dtype.size_in_bytes())
1289    }
1290
1291    fn layer_sizes_in_bytes(
1292        &self,
1293        config: &str,
1294        dtype: DType,
1295        weight_pack_factor: usize,
1296    ) -> Result<Vec<usize>> {
1297        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1298
1299        let per_layer_elems = {
1300            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1301
1302            let size_in = cfg.hidden_size;
1303            let size_q = cfg.head_dim() * cfg.num_attention_heads;
1304            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1305            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1306            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1307            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1308            let o_proj = size_q * size_in / weight_pack_factor + size_in;
1309            let (q_norm, k_norm) = if cfg.qk_layernorm {
1310                (cfg.head_dim(), cfg.head_dim())
1311            } else {
1312                (0, 0)
1313            };
1314
1315            let h_size = cfg.hidden_size;
1316            let i_size = cfg.intermediate_size;
1317            let fc1 = h_size * i_size / weight_pack_factor;
1318            let fc2 = h_size * i_size / weight_pack_factor;
1319
1320            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1321        };
1322        Ok(vec![
1323            per_layer_elems * dtype.size_in_bytes();
1324            cfg.num_hidden_layers
1325        ])
1326    }
1327
1328    fn num_layers(&self, config: &str) -> Result<usize> {
1329        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1330
1331        Ok(cfg.num_hidden_layers)
1332    }
1333
1334    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1335        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1336
1337        let cfg = ModelConfigMetadata {
1338            max_seq_len: cfg.max_position_embeddings,
1339            num_layers: cfg.num_hidden_layers,
1340            hidden_size: cfg.hidden_size,
1341            num_kv_heads: cfg.num_key_value_heads(),
1342            num_attn_heads: cfg.num_attention_heads,
1343            sliding_window: None,
1344            k_head_dim: cfg.head_dim(),
1345            v_head_dim: cfg.head_dim(),
1346        };
1347
1348        Ok(Box::new(cfg))
1349    }
1350}
1351
1352// ======================== Phi3 loader
1353
1354/// [`NormalLoader`] for a Phi 3 model.
1355///
1356/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1357pub struct Phi3Loader;
1358
1359impl NormalModelLoader for Phi3Loader {
1360    fn load(
1361        &self,
1362        config: &str,
1363        vb: ShardedVarBuilder,
1364        normal_loading_metadata: NormalLoadingMetadata,
1365        attention_mechanism: AttentionImplementation,
1366    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1367        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1368
1369        Ok(Box::new(models::phi3::Model::new(
1370            &cfg,
1371            vb,
1372            self.is_gptx(config)?,
1373            normal_loading_metadata,
1374            attention_mechanism,
1375        )?))
1376    }
1377    fn load_xlora(
1378        &self,
1379        config: &str,
1380        vb: ShardedVarBuilder,
1381        lora_config: &[((String, String), LoraConfig)],
1382        xlora_config: Option<XLoraConfig>,
1383        xlora_ordering: Ordering,
1384        normal_loading_metadata: NormalLoadingMetadata,
1385        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1386    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1387        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1388
1389        Ok(Box::new(xlora_models::XLoraPhi3::new(
1390            &cfg,
1391            vb,
1392            lora_config,
1393            xlora_config,
1394            xlora_ordering,
1395            self.is_gptx(config)?,
1396            normal_loading_metadata,
1397            preload_adapters,
1398        )?))
1399    }
1400    fn is_gptx(&self, _: &str) -> Result<bool> {
1401        Ok(true)
1402    }
1403    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1404        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1405
1406        Ok(Box::new(cfg))
1407    }
1408}
1409
1410impl IsqModelLoader for Phi3Loader {
1411    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1412        Ok(vec![
1413            Regex::new(r"lm_head\.(weight|bias)$")?,
1414            // Attention
1415            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1416            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1417            // MLP
1418            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1419            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1420            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1421        ])
1422    }
1423    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1424        self.isq_layer_regexes(config)
1425    }
1426}
1427
1428impl DeviceMappedModelLoader for Phi3Loader {
1429    fn mapped_max_act_size_elems(
1430        &self,
1431        config: &str,
1432        params: &AutoDeviceMapParams,
1433        prompt_chunksize: usize,
1434    ) -> Result<usize> {
1435        let AutoDeviceMapParams::Text {
1436            max_seq_len: _,
1437            max_batch_size,
1438        } = params
1439        else {
1440            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1441        };
1442
1443        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1444
1445        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1446    }
1447    fn non_mapped_max_act_size_elems(
1448        &self,
1449        _config: &str,
1450        _params: &AutoDeviceMapParams,
1451    ) -> Result<usize> {
1452        Ok(0)
1453    }
1454
1455    fn non_mapped_size_in_bytes(
1456        &self,
1457        config: &str,
1458        dtype: DType,
1459        weight_pack_factor: usize,
1460    ) -> Result<usize> {
1461        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1462
1463        let elems = {
1464            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1465            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1466            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1467                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1468            } else {
1469                0
1470            };
1471            let norm = cfg.hidden_size;
1472            embed_tokens + lm_head + norm
1473        };
1474        Ok(elems * dtype.size_in_bytes())
1475    }
1476
1477    fn layer_sizes_in_bytes(
1478        &self,
1479        config: &str,
1480        dtype: DType,
1481        weight_pack_factor: usize,
1482    ) -> Result<Vec<usize>> {
1483        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1484
1485        let per_layer_elems = {
1486            let input_layernorm = cfg.hidden_size;
1487            let post_attention_layernorm = cfg.hidden_size;
1488
1489            let size_in = cfg.hidden_size;
1490            let head_dim = cfg.head_dim();
1491            let op_size =
1492                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1493            let qkv_proj = size_in * op_size / weight_pack_factor;
1494            let o_proj =
1495                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1496
1497            let h_size = cfg.hidden_size;
1498            let i_size = cfg.intermediate_size;
1499            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1500            let down_proj = h_size * i_size / weight_pack_factor;
1501
1502            input_layernorm
1503                + post_attention_layernorm
1504                + qkv_proj
1505                + o_proj
1506                + gate_up_proj
1507                + down_proj
1508        };
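        // All decoder layers are shaped identically, so the same byte count is repeated
        // once per hidden layer.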
1509        Ok(vec![
1510            per_layer_elems * dtype.size_in_bytes();
1511            cfg.num_hidden_layers
1512        ])
1513    }
1514
1515    fn num_layers(&self, config: &str) -> Result<usize> {
1516        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1517
1518        Ok(cfg.num_hidden_layers)
1519    }
1520
1521    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1522        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1523
1524        let cfg = ModelConfigMetadata {
1525            max_seq_len: cfg.max_position_embeddings,
1526            num_layers: cfg.num_hidden_layers,
1527            hidden_size: cfg.hidden_size,
1528            num_kv_heads: cfg.num_key_value_heads,
1529            num_attn_heads: cfg.num_attention_heads,
1530            sliding_window: cfg.sliding_window,
1531            k_head_dim: cfg.head_dim(),
1532            v_head_dim: cfg.head_dim(),
1533        };
1534
1535        Ok(Box::new(cfg))
1536    }
1537}
1538
1539// ======================== Qwen2 loader
1540
1541/// [`NormalLoader`] for a Qwen 2 model.
1542///
1543/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
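///
/// A minimal usage sketch (illustrative only; `config_json` holds the model's
/// `config.json` contents, and `vb`, `metadata`, and `attention_impl` are assumed to be
/// built elsewhere by the surrounding pipeline):
///
/// ```ignore
/// let loader = Qwen2Loader;
/// let n_layers = loader.num_layers(&config_json)?;
/// let per_layer_bytes = loader.layer_sizes_in_bytes(&config_json, DType::F16, 1)?;
/// let model = loader.load(&config_json, vb, metadata, attention_impl)?;
/// ```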
1544pub struct Qwen2Loader;
1545
1546impl NormalModelLoader for Qwen2Loader {
1547    fn load(
1548        &self,
1549        config: &str,
1550        vb: ShardedVarBuilder,
1551        normal_loading_metadata: NormalLoadingMetadata,
1552        attention_mechanism: AttentionImplementation,
1553    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1554        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1555
1556        Ok(Box::new(models::qwen2::Model::new(
1557            &cfg,
1558            vb,
1559            self.is_gptx(config)?,
1560            normal_loading_metadata,
1561            attention_mechanism,
1562        )?))
1563    }
1564    fn load_xlora(
1565        &self,
1566        _config: &str,
1567        _vb: ShardedVarBuilder,
1568        _lora_config: &[((String, String), LoraConfig)],
1569        _xlora_config: Option<XLoraConfig>,
1570        _xlora_ordering: Ordering,
1571        _normal_loading_metadata: NormalLoadingMetadata,
1572        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1573    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1574        todo!()
1575    }
1576    fn is_gptx(&self, _: &str) -> Result<bool> {
1577        Ok(true)
1578    }
1579    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1580        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1581
1582        Ok(Box::new(cfg))
1583    }
1584}
1585
1586impl IsqModelLoader for Qwen2Loader {
1587    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1588        Ok(vec![
1589            Regex::new(r"lm_head\.(weight|bias)$")?,
1590            // Attention
1591            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1592            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1593            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1594            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1595            // MLP
1596            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1597            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1598            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1599        ])
1600    }
1601    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1602        self.isq_layer_regexes(config)
1603    }
1604}
1605
1606impl DeviceMappedModelLoader for Qwen2Loader {
1607    fn mapped_max_act_size_elems(
1608        &self,
1609        config: &str,
1610        params: &AutoDeviceMapParams,
1611        prompt_chunksize: usize,
1612    ) -> Result<usize> {
1613        let AutoDeviceMapParams::Text {
1614            max_seq_len: _,
1615            max_batch_size,
1616        } = params
1617        else {
1618            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1619        };
1620
1621        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1622
1623        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1624    }
1625    fn non_mapped_max_act_size_elems(
1626        &self,
1627        _config: &str,
1628        _params: &AutoDeviceMapParams,
1629    ) -> Result<usize> {
1630        Ok(0)
1631    }
1632
1633    fn non_mapped_size_in_bytes(
1634        &self,
1635        config: &str,
1636        dtype: DType,
1637        weight_pack_factor: usize,
1638    ) -> Result<usize> {
1639        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1640
1641        let elems = {
1642            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1643            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1644            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1645                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1646            } else {
1647                0
1648            };
1649            let norm = cfg.hidden_size;
1650            embed_tokens + lm_head + norm
1651        };
1652        Ok(elems * dtype.size_in_bytes())
1653    }
1654
1655    fn layer_sizes_in_bytes(
1656        &self,
1657        config: &str,
1658        dtype: DType,
1659        weight_pack_factor: usize,
1660    ) -> Result<Vec<usize>> {
1661        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1662
1663        let per_layer_elems = {
1664            let input_layernorm = cfg.hidden_size;
1665            let post_attention_layernorm = cfg.hidden_size;
1666
1667            let size_in = cfg.hidden_size;
1668            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1669            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
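            // Qwen2 checkpoints carry biases on q/k/v_proj but not on o_proj, hence the
            // `+ size_q` / `+ size_kv` terms below and the bias-free o_proj.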
1670            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1671            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1672            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1673            let o_proj = size_q * size_in / weight_pack_factor;
1674
1675            let h_size = cfg.hidden_size;
1676            let i_size = cfg.intermediate_size;
1677            let gate_proj = h_size * i_size / weight_pack_factor;
1678            let up_proj = h_size * i_size / weight_pack_factor;
1679            let down_proj = i_size * h_size / weight_pack_factor;
1680
1681            input_layernorm
1682                + post_attention_layernorm
1683                + q_proj
1684                + k_proj
1685                + v_proj
1686                + o_proj
1687                + gate_proj
1688                + up_proj
1689                + down_proj
1690        };
1691        Ok(vec![
1692            per_layer_elems * dtype.size_in_bytes();
1693            cfg.num_hidden_layers
1694        ])
1695    }
1696
1697    fn num_layers(&self, config: &str) -> Result<usize> {
1698        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1699
1700        Ok(cfg.num_hidden_layers)
1701    }
1702
1703    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1704        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1705
1706        let cfg = ModelConfigMetadata {
1707            max_seq_len: cfg.max_position_embeddings,
1708            num_layers: cfg.num_hidden_layers,
1709            hidden_size: cfg.hidden_size,
1710            num_kv_heads: cfg.num_key_value_heads,
1711            num_attn_heads: cfg.num_attention_heads,
1712            sliding_window: cfg.sliding_window,
1713            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1714            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1715        };
1716
1717        Ok(Box::new(cfg))
1718    }
1719}
1720
1721// ======================== Gemma2 loader
1722
1723/// [`NormalLoader`] for a Gemma2 model.
1724///
1725/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1726pub struct Gemma2Loader;
1727
1728impl NormalModelLoader for Gemma2Loader {
1729    fn load(
1730        &self,
1731        config: &str,
1732        vb: ShardedVarBuilder,
1733        normal_loading_metadata: NormalLoadingMetadata,
1734        attention_mechanism: AttentionImplementation,
1735    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1736        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1737
1738        Ok(Box::new(models::gemma2::Model::new(
1739            &cfg,
1740            vb,
1741            self.is_gptx(config)?,
1742            normal_loading_metadata,
1743            attention_mechanism,
1744        )?))
1745    }
1746    fn load_xlora(
1747        &self,
1748        config: &str,
1749        vb: ShardedVarBuilder,
1750        lora_config: &[((String, String), LoraConfig)],
1751        xlora_config: Option<XLoraConfig>,
1752        xlora_ordering: Ordering,
1753        normal_loading_metadata: NormalLoadingMetadata,
1754        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1755    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1756        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1757
1758        Ok(Box::new(xlora_models::XLoraGemma2::new(
1759            &cfg,
1760            vb,
1761            lora_config,
1762            xlora_config,
1763            xlora_ordering,
1764            self.is_gptx(config)?,
1765            normal_loading_metadata,
1766            preload_adapters,
1767        )?))
1768    }
1769    fn is_gptx(&self, _: &str) -> Result<bool> {
1770        Ok(true)
1771    }
1772    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1773        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1774
1775        Ok(Box::new(cfg))
1776    }
1777}
1778
1779impl IsqModelLoader for Gemma2Loader {
1780    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1781        Ok(vec![
1782            Regex::new(r"lm_head\.(weight|bias)$")?,
1783            // Attention
1784            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1785            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1786            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1787            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1788            // MLP
1789            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1790            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1791            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1792        ])
1793    }
1794    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1795        self.isq_layer_regexes(config)
1796    }
1797}
1798
1799impl DeviceMappedModelLoader for Gemma2Loader {
1800    fn mapped_max_act_size_elems(
1801        &self,
1802        config: &str,
1803        params: &AutoDeviceMapParams,
1804        prompt_chunksize: usize,
1805    ) -> Result<usize> {
1806        let AutoDeviceMapParams::Text {
1807            max_seq_len: _,
1808            max_batch_size,
1809        } = params
1810        else {
1811            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1812        };
1813
1814        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1815
1816        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1817    }
1818    fn non_mapped_max_act_size_elems(
1819        &self,
1820        _config: &str,
1821        _params: &AutoDeviceMapParams,
1822    ) -> Result<usize> {
1823        Ok(0)
1824    }
1825
1826    fn non_mapped_size_in_bytes(
1827        &self,
1828        config: &str,
1829        dtype: DType,
1830        weight_pack_factor: usize,
1831    ) -> Result<usize> {
1832        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1833
1834        let elems = {
1835            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1836            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1837            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1838                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1839            } else {
1840                0
1841            };
1842            let norm = cfg.hidden_size;
1843            embed_tokens + lm_head + norm
1844        };
1845        Ok(elems * dtype.size_in_bytes())
1846    }
1847
1848    fn layer_sizes_in_bytes(
1849        &self,
1850        config: &str,
1851        dtype: DType,
1852        weight_pack_factor: usize,
1853    ) -> Result<Vec<usize>> {
1854        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1855
1856        let per_layer_elems = {
1857            let input_layernorm = cfg.hidden_size;
1858            let post_attention_layernorm = cfg.hidden_size;
1859
1860            let size_in = cfg.hidden_size;
1861            let size_q = cfg.head_dim * cfg.num_attention_heads;
1862            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
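            // `bias_if!(flag, n)` counts `n` bias elements only when `flag` is set
            // (here `attention_bias`), and 0 otherwise.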
1863            let q_proj =
1864                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1865            let k_proj =
1866                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1867            let v_proj =
1868                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1869            let o_proj =
1870                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1871
1872            let h_size = cfg.hidden_size;
1873            let i_size = cfg.intermediate_size;
1874            let gate_proj = h_size * i_size / weight_pack_factor;
1875            let up_proj = h_size * i_size / weight_pack_factor;
1876            let down_proj = i_size * h_size / weight_pack_factor;
1877
1878            input_layernorm
1879                + post_attention_layernorm
1880                + q_proj
1881                + k_proj
1882                + v_proj
1883                + o_proj
1884                + gate_proj
1885                + up_proj
1886                + down_proj
1887        };
1888        Ok(vec![
1889            per_layer_elems * dtype.size_in_bytes();
1890            cfg.num_hidden_layers
1891        ])
1892    }
1893
1894    fn num_layers(&self, config: &str) -> Result<usize> {
1895        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1896
1897        Ok(cfg.num_hidden_layers)
1898    }
1899    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1900        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1901
1902        let cfg = ModelConfigMetadata {
1903            max_seq_len: cfg.max_position_embeddings,
1904            num_layers: cfg.num_hidden_layers,
1905            hidden_size: cfg.hidden_size,
1906            num_kv_heads: cfg.num_key_value_heads,
1907            num_attn_heads: cfg.num_attention_heads,
1908            sliding_window: None, // Be forgiving: some configs do not set a sliding window
1909            k_head_dim: cfg.head_dim,
1910            v_head_dim: cfg.head_dim,
1911        };
1912
1913        Ok(Box::new(cfg))
1914    }
1915}
1916
1917// ======================== Starcoder2 loader
1918
1919/// [`NormalLoader`] for a Starcoder2 model.
1920///
1921/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1922pub struct Starcoder2Loader;
1923
1924impl NormalModelLoader for Starcoder2Loader {
1925    fn load(
1926        &self,
1927        config: &str,
1928        vb: ShardedVarBuilder,
1929        normal_loading_metadata: NormalLoadingMetadata,
1930        attention_mechanism: AttentionImplementation,
1931    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1932        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1933
1934        Ok(Box::new(models::starcoder2::Model::new(
1935            &cfg,
1936            vb,
1937            self.is_gptx(config)?,
1938            normal_loading_metadata,
1939            attention_mechanism,
1940        )?))
1941    }
1942    fn load_xlora(
1943        &self,
1944        config: &str,
1945        vb: ShardedVarBuilder,
1946        lora_config: &[((String, String), LoraConfig)],
1947        xlora_config: Option<XLoraConfig>,
1948        xlora_ordering: Ordering,
1949        normal_loading_metadata: NormalLoadingMetadata,
1950        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1951    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1952        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1953
1954        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
1955            &cfg,
1956            vb,
1957            lora_config,
1958            xlora_config,
1959            xlora_ordering,
1960            self.is_gptx(config)?,
1961            normal_loading_metadata,
1962            preload_adapters,
1963        )?))
1964    }
1965    fn is_gptx(&self, _: &str) -> Result<bool> {
1966        Ok(true)
1967    }
1968    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1969        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1970
1971        Ok(Box::new(cfg))
1972    }
1973}
1974
1975impl IsqModelLoader for Starcoder2Loader {
1976    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1977        Ok(vec![
1978            Regex::new(r"lm_head\.(weight|bias)$")?,
1979            // Attention
1980            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1981            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1982            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1983            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1984            // MLP
1985            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1986            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
1987        ])
1988    }
1989    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1990        self.isq_layer_regexes(config)
1991    }
1992}
1993
1994impl DeviceMappedModelLoader for Starcoder2Loader {
1995    fn mapped_max_act_size_elems(
1996        &self,
1997        config: &str,
1998        params: &AutoDeviceMapParams,
1999        prompt_chunksize: usize,
2000    ) -> Result<usize> {
2001        let AutoDeviceMapParams::Text {
2002            max_seq_len: _,
2003            max_batch_size,
2004        } = params
2005        else {
2006            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2007        };
2008
2009        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2010
2011        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2012    }
2013    fn non_mapped_max_act_size_elems(
2014        &self,
2015        _config: &str,
2016        _params: &AutoDeviceMapParams,
2017    ) -> Result<usize> {
2018        Ok(0)
2019    }
2020
2021    fn non_mapped_size_in_bytes(
2022        &self,
2023        config: &str,
2024        dtype: DType,
2025        weight_pack_factor: usize,
2026    ) -> Result<usize> {
2027        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2028
2029        let elems = {
2030            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2031            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2032            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2033                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2034            } else {
2035                0
2036            };
2037            let norm = cfg.hidden_size + cfg.hidden_size;
2038            embed_tokens + lm_head + norm
2039        };
2040        Ok(elems * dtype.size_in_bytes())
2041    }
2042
2043    fn layer_sizes_in_bytes(
2044        &self,
2045        config: &str,
2046        dtype: DType,
2047        weight_pack_factor: usize,
2048    ) -> Result<Vec<usize>> {
2049        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2050
2051        let per_layer_elems = {
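            // Starcoder2 uses LayerNorm with a bias, so each norm contributes
            // 2 * hidden_size elements (weight + bias) rather than just hidden_size.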
2052            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2053            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2054
2055            let size_in = cfg.hidden_size;
2056            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2057            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2058            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2059            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2060            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2061            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2062
2063            let h_size = cfg.hidden_size;
2064            let i_size = cfg.intermediate_size;
2065            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2066            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2067
2068            input_layernorm
2069                + post_attention_layernorm
2070                + q_proj
2071                + k_proj
2072                + v_proj
2073                + o_proj
2074                + fc1
2075                + fc2
2076        };
2077        Ok(vec![
2078            per_layer_elems * dtype.size_in_bytes();
2079            cfg.num_hidden_layers
2080        ])
2081    }
2082
2083    fn num_layers(&self, config: &str) -> Result<usize> {
2084        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2085
2086        Ok(cfg.num_hidden_layers)
2087    }
2088
2089    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2090        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2091
2092        let cfg = ModelConfigMetadata {
2093            max_seq_len: cfg.max_position_embeddings,
2094            num_layers: cfg.num_hidden_layers,
2095            hidden_size: cfg.hidden_size,
2096            num_kv_heads: cfg.num_key_value_heads,
2097            num_attn_heads: cfg.num_attention_heads,
2098            sliding_window: cfg.sliding_window,
2099            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2100            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2101        };
2102
2103        Ok(Box::new(cfg))
2104    }
2105}
2106
2107// ======================== Phi3.5 MoE loader
2108
2109/// [`NormalLoader`] for a Phi 3.5 MoE model.
2110///
2111/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2112pub struct Phi3_5MoELoader;
2113
2114impl NormalModelLoader for Phi3_5MoELoader {
2115    fn load(
2116        &self,
2117        config: &str,
2118        vb: ShardedVarBuilder,
2119        normal_loading_metadata: NormalLoadingMetadata,
2120        attention_mechanism: AttentionImplementation,
2121    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2122        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2123
2124        Ok(Box::new(models::phi3_5_moe::Model::new(
2125            &cfg,
2126            vb,
2127            self.is_gptx(config)?,
2128            normal_loading_metadata,
2129            attention_mechanism,
2130        )?))
2131    }
2132    fn load_xlora(
2133        &self,
2134        config: &str,
2135        vb: ShardedVarBuilder,
2136        lora_config: &[((String, String), LoraConfig)],
2137        xlora_config: Option<XLoraConfig>,
2138        xlora_ordering: Ordering,
2139        normal_loading_metadata: NormalLoadingMetadata,
2140        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2141    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2142        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2143
2144        Ok(Box::new(xlora_models::XLoraPhi3::new(
2145            &cfg,
2146            vb,
2147            lora_config,
2148            xlora_config,
2149            xlora_ordering,
2150            self.is_gptx(config)?,
2151            normal_loading_metadata,
2152            preload_adapters,
2153        )?))
2154    }
2155    fn is_gptx(&self, _: &str) -> Result<bool> {
2156        Ok(true)
2157    }
2158    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2159        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2160
2161        Ok(Box::new(cfg))
2162    }
2163}
2164
2165impl IsqModelLoader for Phi3_5MoELoader {
2166    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2167        Ok(vec![
2168            Regex::new(r"lm_head\.(weight|bias)$")?,
2169            // Attention
2170            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2171            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2172            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2173            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2174            // MLP
2175            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2176            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2177            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2178        ])
2179    }
2180    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2181        self.isq_layer_regexes(config)
2182    }
2183
2184    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2185        Ok(vec![
2186            Regex::new(r"lm_head\.(weight|bias)$")?,
2187            // MLP
2188            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2189            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2190            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2191        ])
2192    }
2193    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2194        self.isq_layer_regexes_moqe(config)
2195    }
2196}
2197
2198impl DeviceMappedModelLoader for Phi3_5MoELoader {
2199    fn mapped_max_act_size_elems(
2200        &self,
2201        config: &str,
2202        params: &AutoDeviceMapParams,
2203        prompt_chunksize: usize,
2204    ) -> Result<usize> {
2205        let AutoDeviceMapParams::Text {
2206            max_seq_len: _,
2207            max_batch_size,
2208        } = params
2209        else {
2210            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2211        };
2212
2213        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2214
2215        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2216    }
2217    fn non_mapped_max_act_size_elems(
2218        &self,
2219        _config: &str,
2220        _params: &AutoDeviceMapParams,
2221    ) -> Result<usize> {
2222        Ok(0)
2223    }
2224
2225    fn non_mapped_size_in_bytes(
2226        &self,
2227        config: &str,
2228        dtype: DType,
2229        weight_pack_factor: usize,
2230    ) -> Result<usize> {
2231        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2232
2233        let elems = {
2234            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2235            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2236            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2237                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2238            } else {
2239                0
2240            };
2241            let norm = cfg.hidden_size;
2242            embed_tokens + lm_head + norm
2243        };
2244        Ok(elems * dtype.size_in_bytes())
2245    }
2246
2247    fn layer_sizes_in_bytes(
2248        &self,
2249        config: &str,
2250        dtype: DType,
2251        weight_pack_factor: usize,
2252    ) -> Result<Vec<usize>> {
2253        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2254
2255        let per_layer_elems = {
2256            let input_layernorm = cfg.hidden_size;
2257            let post_attention_layernorm = cfg.hidden_size;
2258
2259            let size_in = cfg.hidden_size;
2260            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2261            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2262            let q_proj =
2263                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2264            let k_proj =
2265                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2266            let v_proj =
2267                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2268            let o_proj =
2269                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2270
2271            let moe_block = {
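                // Mixtral-style expert naming: w1 = gate_proj and w3 = up_proj
                // (hidden -> intermediate), w2 = down_proj (intermediate -> hidden),
                // each replicated once per local expert.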
2272                let gate = cfg.hidden_size * cfg.num_local_experts;
2273                // Assume quantizing weight pack factor
2274                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2275                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2276                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2277                gate + cfg.num_local_experts * w1
2278                    + cfg.num_local_experts * w2
2279                    + cfg.num_local_experts * w3
2280            };
2281
2282            input_layernorm
2283                + post_attention_layernorm
2284                + q_proj
2285                + k_proj
2286                + v_proj
2287                + o_proj
2288                + moe_block
2289        };
2290        Ok(vec![
2291            per_layer_elems * dtype.size_in_bytes();
2292            cfg.num_hidden_layers
2293        ])
2294    }
2295
2296    fn num_layers(&self, config: &str) -> Result<usize> {
2297        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2298
2299        Ok(cfg.num_hidden_layers)
2300    }
2301
2302    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2303        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2304
2305        let cfg = ModelConfigMetadata {
2306            max_seq_len: cfg.max_position_embeddings,
2307            num_layers: cfg.num_hidden_layers,
2308            hidden_size: cfg.hidden_size,
2309            num_kv_heads: cfg.num_key_value_heads,
2310            num_attn_heads: cfg.num_attention_heads,
2311            sliding_window: cfg.sliding_window,
2312            k_head_dim: cfg.head_dim(),
2313            v_head_dim: cfg.head_dim(),
2314        };
2315
2316        Ok(Box::new(cfg))
2317    }
2318}
2319
// ======================== DeepSeekV2 loader

2320/// [`NormalLoader`] for a DeepSeekV2 model.
2321///
2322/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2323pub struct DeepSeekV2Loader;
2324
2325impl NormalModelLoader for DeepSeekV2Loader {
2326    fn load(
2327        &self,
2328        config: &str,
2329        vb: ShardedVarBuilder,
2330        normal_loading_metadata: NormalLoadingMetadata,
2331        attention_mechanism: AttentionImplementation,
2332    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2333        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2334
2335        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2336            &cfg,
2337            vb,
2338            self.is_gptx(config)?,
2339            normal_loading_metadata,
2340            attention_mechanism,
2341        )?))
2342    }
2343    fn load_xlora(
2344        &self,
2345        _config: &str,
2346        _vb: ShardedVarBuilder,
2347        _lora_config: &[((String, String), LoraConfig)],
2348        _xlora_config: Option<XLoraConfig>,
2349        _xlora_ordering: Ordering,
2350        _normal_loading_metadata: NormalLoadingMetadata,
2351        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2352    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2353        todo!()
2354    }
2355    fn is_gptx(&self, _: &str) -> Result<bool> {
2356        Ok(true)
2357    }
2358    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2359        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2360        Ok(Box::new(cfg))
2361    }
2362}
2363
2364impl IsqModelLoader for DeepSeekV2Loader {
2365    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2366        let mut data = vec![
2367            Regex::new(r"lm_head\.(weight|bias)$")?,
2368            // Attention
2369            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2370            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2371            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2372        ];
2373        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2374        if cfg.q_lora_rank.is_some() {
2375            data.extend(vec![
2376                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2377                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2378            ]);
2379        } else {
2380            data.push(Regex::new(
2381                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2382            )?);
2383        }
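        // Per-layer patterns: MoE layers get one pattern per routed expert (plus shared
        // experts when present); dense layers fall back to plain gate/up/down projections.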
2384        for layer_idx in 0..cfg.num_hidden_layers {
2385            if cfg.n_routed_experts.is_some()
2386                && layer_idx >= cfg.first_k_dense_replace
2387                && layer_idx % cfg.moe_layer_freq == 0
2388            {
2389                for i in 0..cfg.n_routed_experts.unwrap() {
2390                    data.extend(vec![
2391                        Regex::new(&format!(
2392                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2393                        ))?,
2394                        Regex::new(&format!(
2395                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2396                        ))?,
2397                        Regex::new(&format!(
2398                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2399                        ))?,
2400                    ]);
2401                }
2402                if cfg.n_shared_experts.is_some() {
2403                    data.extend(vec![
2404                        Regex::new(&format!(
2405                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2406                        ))?,
2407                        Regex::new(&format!(
2408                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2409                        ))?,
2410                        Regex::new(&format!(
2411                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2412                        ))?,
2413                    ]);
2414                }
2415            } else {
2416                data.extend(vec![
2417                    Regex::new(&format!(
2418                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2419                    ))?,
2420                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2421                    Regex::new(&format!(
2422                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2423                    ))?,
2424                ]);
2425            };
2426        }
2427        Ok(data)
2428    }
2429    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2430        self.isq_layer_regexes(config)
2431    }
2432
2433    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2434        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2435        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2436        for layer_idx in 0..cfg.num_hidden_layers {
2437            if cfg.n_routed_experts.is_some()
2438                && layer_idx >= cfg.first_k_dense_replace
2439                && layer_idx % cfg.moe_layer_freq == 0
2440            {
2441                for i in 0..cfg.n_routed_experts.unwrap() {
2442                    data.extend(vec![
2443                        Regex::new(&format!(
2444                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2445                        ))?,
2446                        Regex::new(&format!(
2447                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2448                        ))?,
2449                        Regex::new(&format!(
2450                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2451                        ))?,
2452                    ]);
2453                }
2454                if cfg.n_shared_experts.is_some() {
2455                    data.extend(vec![
2456                        Regex::new(&format!(
2457                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2458                        ))?,
2459                        Regex::new(&format!(
2460                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2461                        ))?,
2462                        Regex::new(&format!(
2463                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2464                        ))?,
2465                    ]);
2466                }
2467            } else {
2468                data.extend(vec![
2469                    Regex::new(&format!(
2470                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2471                    ))?,
2472                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2473                    Regex::new(&format!(
2474                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2475                    ))?,
2476                ]);
2477            };
2478        }
2479        Ok(data)
2480    }
2481    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2482        self.isq_layer_regexes_moqe(config)
2483    }
2484}
2485
2486impl DeviceMappedModelLoader for DeepSeekV2Loader {
2487    fn mapped_max_act_size_elems(
2488        &self,
2489        config: &str,
2490        params: &AutoDeviceMapParams,
2491        prompt_chunksize: usize,
2492    ) -> Result<usize> {
2493        let AutoDeviceMapParams::Text {
2494            max_seq_len: _,
2495            max_batch_size,
2496        } = params
2497        else {
2498            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2499        };
2500
2501        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2502
2503        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2504    }
2505    fn non_mapped_max_act_size_elems(
2506        &self,
2507        _config: &str,
2508        _params: &AutoDeviceMapParams,
2509    ) -> Result<usize> {
2510        Ok(0)
2511    }
2512
2513    fn non_mapped_size_in_bytes(
2514        &self,
2515        config: &str,
2516        dtype: DType,
2517        weight_pack_factor: usize,
2518    ) -> Result<usize> {
2519        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2520        let elems = {
2521            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2522            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2523            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2524                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2525            } else {
2526                0
2527            };
2528            let norm = cfg.hidden_size;
2529            embed_tokens + lm_head + norm
2530        };
2531        Ok(elems * dtype.size_in_bytes())
2532    }
2533
2534    fn layer_sizes_in_bytes(
2535        &self,
2536        config: &str,
2537        dtype: DType,
2538        weight_pack_factor: usize,
2539    ) -> Result<Vec<usize>> {
2540        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2541        let mut per_layer_elems = Vec::new();
2542
2543        for layer_idx in 0..cfg.num_hidden_layers {
2544            let input_layernorm = cfg.hidden_size;
2545            let post_attention_layernorm = cfg.hidden_size;
2546
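            // Multi-head latent attention: q (when q_lora_rank is set) and kv go through
            // low-rank factorizations, so the count is a_proj + norm + b_proj instead of a
            // single dense projection.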
2547            let q_proj = match cfg.q_lora_rank {
2548                Some(lora_rank) => {
2549                    let a = cfg.hidden_size * lora_rank;
2550                    let norm = lora_rank;
2551                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2552                    a + norm + b
2553                }
2554                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2555            };
2556            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2557                / weight_pack_factor
2558                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2559            let kv_a_layernorm = cfg.kv_lora_rank;
2560            let kv_b_proj = cfg.kv_lora_rank
2561                * cfg.num_attention_heads
2562                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2563                / weight_pack_factor;
2564            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2565                / weight_pack_factor
2566                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2567
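            // Layers before `first_k_dense_replace`, and layers skipped by `moe_layer_freq`,
            // use a plain dense MLP; all other layers use the routed-expert MoE block.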
2568            let moe_block = {
2569                let mut sum = 0;
2570                if cfg.n_routed_experts.is_some()
2571                    && layer_idx >= cfg.first_k_dense_replace
2572                    && layer_idx % cfg.moe_layer_freq == 0
2573                {
2574                    let h_size = cfg.hidden_size;
2575                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2576                        * cfg.n_routed_experts.unwrap();
2577                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2578                        * cfg.n_routed_experts.unwrap();
2579                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2580                        * cfg.n_routed_experts.unwrap();
2581                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2582                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2583                            / weight_pack_factor;
2584                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2585                            / weight_pack_factor;
2586                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2587                            / weight_pack_factor;
2588                        gate_proj + up_proj + down_proj
2589                    } else {
2590                        0
2591                    };
2592                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2593                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2594                } else {
2595                    let h_size = cfg.hidden_size;
2596                    let i_size = cfg.intermediate_size;
2597                    let gate_proj = h_size * i_size / weight_pack_factor;
2598                    let up_proj = h_size * i_size / weight_pack_factor;
2599                    let down_proj = i_size * h_size / weight_pack_factor;
2600                    sum += gate_proj + up_proj + down_proj;
2601                }
2602                sum
2603            };
2604
2605            per_layer_elems.push(
2606                input_layernorm
2607                    + post_attention_layernorm
2608                    + q_proj
2609                    + kv_a_layernorm
2610                    + kv_a_proj_with_mqa
2611                    + kv_b_proj
2612                    + o_proj
2613                    + moe_block,
2614            );
2615        }
2616
2617        Ok(per_layer_elems
2618            .into_iter()
2619            .map(|x| x * dtype.size_in_bytes())
2620            .collect())
2621    }
2622
2623    fn num_layers(&self, config: &str) -> Result<usize> {
2624        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2625        Ok(cfg.num_hidden_layers)
2626    }
2627
2628    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2629        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2630
2631        let cfg = ModelConfigMetadata {
2632            max_seq_len: cfg.max_position_embeddings,
2633            num_layers: cfg.num_hidden_layers,
2634            hidden_size: cfg.hidden_size,
2635            num_kv_heads: cfg.num_attention_heads,
2636            num_attn_heads: cfg.num_attention_heads,
2637            sliding_window: None,
2638            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2639            v_head_dim: cfg.v_head_dim,
2640        };
2641
2642        Ok(Box::new(cfg))
2643    }
2644}
2645
// ======================== DeepSeekV3 loader

2646/// [`NormalLoader`] for a DeepSeekV3 model.
2647///
2648/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
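///
/// A rough size-estimation sketch (illustrative only; `config_json` holds the model's
/// `config.json` contents, and a weight pack factor of 1 assumes unquantized weights):
///
/// ```ignore
/// let loader = DeepSeekV3Loader;
/// let fixed_bytes = loader.non_mapped_size_in_bytes(&config_json, DType::BF16, 1)?;
/// let per_layer_bytes = loader.layer_sizes_in_bytes(&config_json, DType::BF16, 1)?;
/// let total_weight_bytes: usize = fixed_bytes + per_layer_bytes.iter().sum::<usize>();
/// ```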
2649pub struct DeepSeekV3Loader;
2650
2651impl NormalModelLoader for DeepSeekV3Loader {
2652    fn load(
2653        &self,
2654        config: &str,
2655        vb: ShardedVarBuilder,
2656        normal_loading_metadata: NormalLoadingMetadata,
2657        attention_mechanism: AttentionImplementation,
2658    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2659        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2660        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2661            &cfg,
2662            vb,
2663            self.is_gptx(config)?,
2664            normal_loading_metadata,
2665            attention_mechanism,
2666        )?))
2667    }
2668    fn load_xlora(
2669        &self,
2670        _config: &str,
2671        _vb: ShardedVarBuilder,
2672        _lora_config: &[((String, String), LoraConfig)],
2673        _xlora_config: Option<XLoraConfig>,
2674        _xlora_ordering: Ordering,
2675        _normal_loading_metadata: NormalLoadingMetadata,
2676        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2677    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2678        todo!()
2679    }
2680    fn is_gptx(&self, _: &str) -> Result<bool> {
2681        Ok(true)
2682    }
2683    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2684        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2685        Ok(Box::new(cfg))
2686    }
2687}
2688
2689impl IsqModelLoader for DeepSeekV3Loader {
2690    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2691        let mut data = vec![
2692            Regex::new(r"lm_head\.(weight|bias)$")?,
2693            // Attention
2694            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2695            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2696            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2697        ];
2698        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2699        if cfg.q_lora_rank.is_some() {
2700            data.extend(vec![
2701                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2702                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2703            ]);
2704        } else {
2705            data.push(Regex::new(
2706                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2707            )?);
2708        }
2709        for layer_idx in 0..cfg.num_hidden_layers {
2710            if cfg.n_routed_experts.is_some()
2711                && layer_idx >= cfg.first_k_dense_replace
2712                && layer_idx % cfg.moe_layer_freq == 0
2713            {
2714                for i in 0..cfg.n_routed_experts.unwrap() {
2715                    data.extend(vec![
2716                        Regex::new(&format!(
2717                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2718                        ))?,
2719                        Regex::new(&format!(
2720                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2721                        ))?,
2722                        Regex::new(&format!(
2723                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2724                        ))?,
2725                    ]);
2726                }
2727                if cfg.n_shared_experts.is_some() {
2728                    data.extend(vec![
2729                        Regex::new(&format!(
2730                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2731                        ))?,
2732                        Regex::new(&format!(
2733                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2734                        ))?,
2735                        Regex::new(&format!(
2736                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2737                        ))?,
2738                    ]);
2739                }
2740            } else {
2741                data.extend(vec![
2742                    Regex::new(&format!(
2743                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2744                    ))?,
2745                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2746                    Regex::new(&format!(
2747                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2748                    ))?,
2749                ]);
2750            };
2751        }
2752        Ok(data)
2753    }
2754    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2755        self.isq_layer_regexes(config)
2756    }
2757
2758    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2759        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2760        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2761        for layer_idx in 0..cfg.num_hidden_layers {
2762            if cfg.n_routed_experts.is_some()
2763                && layer_idx >= cfg.first_k_dense_replace
2764                && layer_idx % cfg.moe_layer_freq == 0
2765            {
2766                for i in 0..cfg.n_routed_experts.unwrap() {
2767                    data.extend(vec![
2768                        Regex::new(&format!(
2769                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2770                        ))?,
2771                        Regex::new(&format!(
2772                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2773                        ))?,
2774                        Regex::new(&format!(
2775                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2776                        ))?,
2777                    ]);
2778                }
2779                if cfg.n_shared_experts.is_some() {
2780                    data.extend(vec![
2781                        Regex::new(&format!(
2782                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2783                        ))?,
2784                        Regex::new(&format!(
2785                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2786                        ))?,
2787                        Regex::new(&format!(
2788                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2789                        ))?,
2790                    ]);
2791                }
2792            } else {
2793                data.extend(vec![
2794                    Regex::new(&format!(
2795                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2796                    ))?,
2797                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2798                    Regex::new(&format!(
2799                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2800                    ))?,
2801                ]);
2802            };
2803        }
2804        Ok(data)
2805    }
2806    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2807        self.isq_layer_regexes_moqe(config)
2808    }
2809}
2810
2811impl DeviceMappedModelLoader for DeepSeekV3Loader {
2812    fn mapped_max_act_size_elems(
2813        &self,
2814        config: &str,
2815        params: &AutoDeviceMapParams,
2816        prompt_chunksize: usize,
2817    ) -> Result<usize> {
2818        let AutoDeviceMapParams::Text {
2819            max_seq_len: _,
2820            max_batch_size,
2821        } = params
2822        else {
2823            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2824        };
2825
2826        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2827
2828        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2829    }
2830    fn non_mapped_max_act_size_elems(
2831        &self,
2832        _config: &str,
2833        _params: &AutoDeviceMapParams,
2834    ) -> Result<usize> {
2835        Ok(0)
2836    }
2837
2838    fn non_mapped_size_in_bytes(
2839        &self,
2840        config: &str,
2841        dtype: DType,
2842        weight_pack_factor: usize,
2843    ) -> Result<usize> {
2844        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2845        let elems = {
2846            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2847            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2848            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2849                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2850            } else {
2851                0
2852            };
2853            let norm = cfg.hidden_size;
2854            embed_tokens + lm_head + norm
2855        };
2856        Ok(elems * dtype.size_in_bytes())
2857    }
2858
2859    fn layer_sizes_in_bytes(
2860        &self,
2861        config: &str,
2862        dtype: DType,
2863        weight_pack_factor: usize,
2864    ) -> Result<Vec<usize>> {
2865        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2866        let mut per_layer_elems = Vec::new();
2867
2868        for layer_idx in 0..cfg.num_hidden_layers {
2869            let input_layernorm = cfg.hidden_size;
2870            let post_attention_layernorm = cfg.hidden_size;
2871
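            // MLA-style attention: q optionally factors through a low-rank q_lora_rank (down-proj, norm, up-proj),
            // while kv always goes through kv_lora_rank plus the rotary (qk_rope_head_dim) component.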
2872            let q_proj = match cfg.q_lora_rank {
2873                Some(lora_rank) => {
2874                    let a = cfg.hidden_size * lora_rank;
2875                    let norm = lora_rank;
2876                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2877                    a + norm + b
2878                }
2879                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2880            };
2881            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2882                / weight_pack_factor
2883                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2884            let kv_a_layernorm = cfg.kv_lora_rank;
2885            let kv_b_proj = cfg.kv_lora_rank
2886                * cfg.num_attention_heads
2887                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2888                / weight_pack_factor;
2889            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2890                / weight_pack_factor
2891                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2892
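            // MoE applies from layer first_k_dense_replace onward at every moe_layer_freq-th layer;
            // all other layers fall back to a dense MLP of intermediate_size.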
2893            let moe_block = {
2894                let mut sum = 0;
2895                if cfg.n_routed_experts.is_some()
2896                    && layer_idx >= cfg.first_k_dense_replace
2897                    && layer_idx % cfg.moe_layer_freq == 0
2898                {
2899                    let h_size = cfg.hidden_size;
2900                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2901                        * cfg.n_routed_experts.unwrap();
2902                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2903                        * cfg.n_routed_experts.unwrap();
2904                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2905                        * cfg.n_routed_experts.unwrap();
2906                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2907                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2908                            / weight_pack_factor;
2909                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2910                            / weight_pack_factor;
2911                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2912                            / weight_pack_factor;
2913                        gate_proj + up_proj + down_proj
2914                    } else {
2915                        0
2916                    };
2917                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2918                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2919                } else {
2920                    let h_size = cfg.hidden_size;
2921                    let i_size = cfg.intermediate_size;
2922                    let gate_proj = h_size * i_size / weight_pack_factor;
2923                    let up_proj = h_size * i_size / weight_pack_factor;
2924                    let down_proj = i_size * h_size / weight_pack_factor;
2925                    sum += gate_proj + up_proj + down_proj;
2926                }
2927                sum
2928            };
2929
2930            per_layer_elems.push(
2931                input_layernorm
2932                    + post_attention_layernorm
2933                    + q_proj
2934                    + kv_a_layernorm
2935                    + kv_a_proj_with_mqa
2936                    + kv_b_proj
2937                    + o_proj
2938                    + moe_block,
2939            );
2940        }
2941
2942        Ok(per_layer_elems
2943            .into_iter()
2944            .map(|x| x * dtype.size_in_bytes())
2945            .collect())
2946    }
2947
2948    fn num_layers(&self, config: &str) -> Result<usize> {
2949        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2950        Ok(cfg.num_hidden_layers)
2951    }
2952
2953    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2954        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2955
2956        let cfg = ModelConfigMetadata {
2957            max_seq_len: cfg.max_position_embeddings,
2958            num_layers: cfg.num_hidden_layers,
2959            hidden_size: cfg.hidden_size,
2960            num_kv_heads: cfg.num_attention_heads,
2961            num_attn_heads: cfg.num_attention_heads,
2962            sliding_window: None,
2963            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2964            v_head_dim: cfg.v_head_dim,
2965        };
2966
2967        Ok(Box::new(cfg))
2968    }
2969}
2970
2971/// [`NormalLoader`] for a Qwen 3 model.
2972///
2973/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
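///
/// A minimal, hedged sketch of querying a loader's metadata helpers (assumes
/// `config_json` holds the contents of the model's `config.json`; illustrative
/// only, not a full loading pipeline):
///
/// ```ignore
/// let loader = Qwen3Loader;
/// // Number of decoder layers parsed from the config.
/// let n_layers = loader.num_layers(&config_json)?;
/// // Regexes selecting the linear weights eligible for in-situ quantization (ISQ).
/// let isq_targets = loader.isq_layer_regexes(&config_json)?;
/// ```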
2974pub struct Qwen3Loader;
2975
2976impl NormalModelLoader for Qwen3Loader {
2977    fn load(
2978        &self,
2979        config: &str,
2980        vb: ShardedVarBuilder,
2981        normal_loading_metadata: NormalLoadingMetadata,
2982        attention_mechanism: AttentionImplementation,
2983    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2984        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
2985
2986        Ok(Box::new(models::qwen3::Model::new(
2987            &cfg,
2988            vb,
2989            self.is_gptx(config)?,
2990            normal_loading_metadata,
2991            attention_mechanism,
2992        )?))
2993    }
2994    fn load_xlora(
2995        &self,
2996        _config: &str,
2997        _vb: ShardedVarBuilder,
2998        _lora_config: &[((String, String), LoraConfig)],
2999        _xlora_config: Option<XLoraConfig>,
3000        _xlora_ordering: Ordering,
3001        _normal_loading_metadata: NormalLoadingMetadata,
3002        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3003    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3004        todo!()
3005    }
3006    fn is_gptx(&self, _: &str) -> Result<bool> {
3007        Ok(true)
3008    }
3009    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3010        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3011
3012        Ok(Box::new(cfg))
3013    }
3014}
3015
3016impl IsqModelLoader for Qwen3Loader {
3017    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3018        Ok(vec![
3019            Regex::new(r"lm_head\.(weight|bias)$")?,
3020            // Attention
3021            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3022            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3023            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3024            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3025            // MLP
3026            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3027            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3028            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3029        ])
3030    }
3031    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3032        self.isq_layer_regexes(config)
3033    }
3034}
3035
3036impl DeviceMappedModelLoader for Qwen3Loader {
3037    fn mapped_max_act_size_elems(
3038        &self,
3039        config: &str,
3040        params: &AutoDeviceMapParams,
3041        prompt_chunksize: usize,
3042    ) -> Result<usize> {
3043        let AutoDeviceMapParams::Text {
3044            max_seq_len: _,
3045            max_batch_size,
3046        } = params
3047        else {
3048            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3049        };
3050
3051        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3052
3053        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3054    }
3055    fn non_mapped_max_act_size_elems(
3056        &self,
3057        _config: &str,
3058        _params: &AutoDeviceMapParams,
3059    ) -> Result<usize> {
3060        Ok(0)
3061    }
3062
3063    fn non_mapped_size_in_bytes(
3064        &self,
3065        config: &str,
3066        dtype: DType,
3067        weight_pack_factor: usize,
3068    ) -> Result<usize> {
3069        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3070        let elems = {
3071            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3072            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3073            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3074                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3075            } else {
3076                0
3077            };
3078            let norm = cfg.hidden_size;
3079            embed_tokens + lm_head + norm
3080        };
3081        Ok(elems * dtype.size_in_bytes())
3082    }
3083
3084    fn layer_sizes_in_bytes(
3085        &self,
3086        config: &str,
3087        dtype: DType,
3088        weight_pack_factor: usize,
3089    ) -> Result<Vec<usize>> {
3090        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3091        let per_layer_elems = {
3092            let input_layernorm = cfg.hidden_size;
3093            let post_attention_layernorm = cfg.hidden_size;
3094
3095            let size_in = cfg.hidden_size;
3096            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3097            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3098            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3099            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3100            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3101            let o_proj = size_q * size_in / weight_pack_factor;
3102
3103            let h_size = cfg.hidden_size;
3104            let i_size = cfg.intermediate_size;
3105            let gate_proj = h_size * i_size / weight_pack_factor;
3106            let up_proj = h_size * i_size / weight_pack_factor;
3107            let down_proj = i_size * h_size / weight_pack_factor;
3108
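            // Qwen3 applies per-head RMSNorm to q and k; each norm holds head_dim elements.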
3109            let q_norm = cfg.head_dim();
3110            let k_norm = cfg.head_dim();
3111
3112            input_layernorm
3113                + post_attention_layernorm
3114                + q_proj
3115                + k_proj
3116                + v_proj
3117                + o_proj
3118                + gate_proj
3119                + up_proj
3120                + down_proj
3121                + q_norm
3122                + k_norm
3123        };
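        // All decoder layers are identical, so replicate a single layer's byte count.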
3124        Ok(vec![
3125            per_layer_elems * dtype.size_in_bytes();
3126            cfg.num_hidden_layers
3127        ])
3128    }
3129
3130    fn num_layers(&self, config: &str) -> Result<usize> {
3131        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3132        Ok(cfg.num_hidden_layers)
3133    }
3134
3135    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3136        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3137
3138        let cfg = ModelConfigMetadata {
3139            max_seq_len: cfg.max_position_embeddings,
3140            num_layers: cfg.num_hidden_layers,
3141            hidden_size: cfg.hidden_size,
3142            num_kv_heads: cfg.num_key_value_heads,
3143            num_attn_heads: cfg.num_attention_heads,
3144            sliding_window: cfg.sliding_window,
3145            k_head_dim: cfg.head_dim(),
3146            v_head_dim: cfg.head_dim(),
3147        };
3148
3149        Ok(Box::new(cfg))
3150    }
3151}
3152
3153/// [`NormalLoader`] for a GLM 4 model.
3154///
3155/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3156pub struct GLM4Loader;
3157
3158impl NormalModelLoader for GLM4Loader {
3159    fn load(
3160        &self,
3161        config: &str,
3162        vb: ShardedVarBuilder,
3163        normal_loading_metadata: NormalLoadingMetadata,
3164        attention_mechanism: AttentionImplementation,
3165    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3166        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3167
3168        Ok(Box::new(models::glm4::Model::new(
3169            &cfg,
3170            vb,
3171            self.is_gptx(config)?,
3172            normal_loading_metadata,
3173            attention_mechanism,
3174        )?))
3175    }
3176    fn load_xlora(
3177        &self,
3178        _config: &str,
3179        _vb: ShardedVarBuilder,
3180        _lora_config: &[((String, String), LoraConfig)],
3181        _xlora_config: Option<XLoraConfig>,
3182        _xlora_ordering: Ordering,
3183        _normal_loading_metadata: NormalLoadingMetadata,
3184        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3185    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3186        todo!()
3187    }
3188    fn is_gptx(&self, _: &str) -> Result<bool> {
3189        Ok(true)
3190    }
3191    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3192        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3193
3194        Ok(Box::new(cfg))
3195    }
3196}
3197
3198impl IsqModelLoader for GLM4Loader {
3199    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3200        Ok(vec![
3201            Regex::new(r"lm_head\.(weight|bias)$")?,
3202            // Attention
3203            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3204            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3205            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3206            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3207            // MLP
3208            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3209            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3210            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3211        ])
3212    }
3213    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3214        self.isq_layer_regexes(config)
3215    }
3216}
3217
3218impl DeviceMappedModelLoader for GLM4Loader {
3219    fn mapped_max_act_size_elems(
3220        &self,
3221        config: &str,
3222        params: &AutoDeviceMapParams,
3223        prompt_chunksize: usize,
3224    ) -> Result<usize> {
3225        let AutoDeviceMapParams::Text {
3226            max_seq_len: _,
3227            max_batch_size,
3228        } = params
3229        else {
3230            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3231        };
3232
3233        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3234
3235        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3236    }
3237    fn non_mapped_max_act_size_elems(
3238        &self,
3239        _config: &str,
3240        _params: &AutoDeviceMapParams,
3241    ) -> Result<usize> {
3242        Ok(0)
3243    }
3244
3245    fn non_mapped_size_in_bytes(
3246        &self,
3247        config: &str,
3248        dtype: DType,
3249        weight_pack_factor: usize,
3250    ) -> Result<usize> {
3251        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3252        let elems = {
3253            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3254            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3255            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3256                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3257            } else {
3258                0
3259            };
3260            let norm = cfg.hidden_size;
3261            embed_tokens + lm_head + norm
3262        };
3263        Ok(elems * dtype.size_in_bytes())
3264    }
3265
3266    fn layer_sizes_in_bytes(
3267        &self,
3268        config: &str,
3269        dtype: DType,
3270        weight_pack_factor: usize,
3271    ) -> Result<Vec<usize>> {
3272        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3273        let per_layer_elems = {
3274            let input_layernorm = cfg.hidden_size;
3275            let post_attention_layernorm = cfg.hidden_size * 3; // includes post_self_attn_layernorm and post_mlp_layernorm
3276
3277            let size_in = cfg.hidden_size;
3278            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3279            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3280            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3281            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3282            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3283            let o_proj = size_q * size_in / weight_pack_factor;
3284
3285            let h_size = cfg.hidden_size;
3286            let i_size = cfg.intermediate_size;
3287            let gate_proj = h_size * i_size / weight_pack_factor;
3288            let up_proj = h_size * i_size / weight_pack_factor;
3289            let down_proj = i_size * h_size / weight_pack_factor;
3290
3291            input_layernorm
3292                + post_attention_layernorm
3293                + q_proj
3294                + k_proj
3295                + v_proj
3296                + o_proj
3297                + gate_proj
3298                + up_proj
3299                + down_proj
3300        };
3301        Ok(vec![
3302            per_layer_elems * dtype.size_in_bytes();
3303            cfg.num_hidden_layers
3304        ])
3305    }
3306
3307    fn num_layers(&self, config: &str) -> Result<usize> {
3308        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3309        Ok(cfg.num_hidden_layers)
3310    }
3311
3312    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3313        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3314
3315        let cfg = ModelConfigMetadata {
3316            max_seq_len: cfg.max_position_embeddings,
3317            num_layers: cfg.num_hidden_layers,
3318            hidden_size: cfg.hidden_size,
3319            num_kv_heads: cfg.num_key_value_heads,
3320            num_attn_heads: cfg.num_attention_heads,
3321            sliding_window: cfg.sliding_window,
3322            k_head_dim: cfg.head_dim(),
3323            v_head_dim: cfg.head_dim(),
3324        };
3325
3326        Ok(Box::new(cfg))
3327    }
3328}
3329
3330/// [`NormalLoader`] for a Qwen 3 MoE model.
3331///
3332/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
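///
/// A hedged sketch contrasting the full set of ISQ predicates with the MoE-only
/// (MoQE) subset (assumes `config_json` holds the model config; illustrative only):
///
/// ```ignore
/// let loader = Qwen3MoELoader;
/// // Predicates covering lm_head, attention, dense MLP, and expert projections.
/// let all = loader.immediate_isq_predicates(&config_json)?;
/// // Predicates used when quantizing only the mixture-of-experts layers.
/// let experts_only = loader.immediate_isq_predicates_moqe(&config_json)?;
/// ```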
3333pub struct Qwen3MoELoader;
3334
3335impl NormalModelLoader for Qwen3MoELoader {
3336    fn load(
3337        &self,
3338        config: &str,
3339        vb: ShardedVarBuilder,
3340        normal_loading_metadata: NormalLoadingMetadata,
3341        attention_mechanism: AttentionImplementation,
3342    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3343        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3344
3345        Ok(Box::new(models::qwen3_moe::Model::new(
3346            &cfg,
3347            vb,
3348            self.is_gptx(config)?,
3349            normal_loading_metadata,
3350            attention_mechanism,
3351        )?))
3352    }
3353    fn load_xlora(
3354        &self,
3355        _config: &str,
3356        _vb: ShardedVarBuilder,
3357        _lora_config: &[((String, String), LoraConfig)],
3358        _xlora_config: Option<XLoraConfig>,
3359        _xlora_ordering: Ordering,
3360        _normal_loading_metadata: NormalLoadingMetadata,
3361        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3362    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3363        todo!()
3364    }
3365    fn is_gptx(&self, _: &str) -> Result<bool> {
3366        Ok(true)
3367    }
3368    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3369        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3370
3371        Ok(Box::new(cfg))
3372    }
3373}
3374
3375impl IsqModelLoader for Qwen3MoELoader {
3376    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3377        Ok(vec![
3378            Regex::new(r"lm_head\.(weight|bias)$")?,
3379            // Attention
3380            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3381            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3382            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3383            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3384            // MLP
3385            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3386            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3387            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3388            // MLP MoE
3389            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3390            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3391            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3392        ])
3393    }
3394    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3395        self.isq_layer_regexes(config)
3396    }
3397    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3398        self.isq_layer_regexes_moqe(config)
3399    }
3400}
3401
3402impl DeviceMappedModelLoader for Qwen3MoELoader {
3403    fn mapped_max_act_size_elems(
3404        &self,
3405        config: &str,
3406        params: &AutoDeviceMapParams,
3407        prompt_chunksize: usize,
3408    ) -> Result<usize> {
3409        let AutoDeviceMapParams::Text {
3410            max_seq_len: _,
3411            max_batch_size,
3412        } = params
3413        else {
3414            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3415        };
3416
3417        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3418
3419        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3420    }
3421    fn non_mapped_max_act_size_elems(
3422        &self,
3423        _config: &str,
3424        _params: &AutoDeviceMapParams,
3425    ) -> Result<usize> {
3426        Ok(0)
3427    }
3428
3429    fn non_mapped_size_in_bytes(
3430        &self,
3431        config: &str,
3432        dtype: DType,
3433        weight_pack_factor: usize,
3434    ) -> Result<usize> {
3435        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3436        let elems = {
3437            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3438            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3439            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3440                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3441            } else {
3442                0
3443            };
3444            let norm = cfg.hidden_size;
3445            embed_tokens + lm_head + norm
3446        };
3447        Ok(elems * dtype.size_in_bytes())
3448    }
3449
3450    fn layer_sizes_in_bytes(
3451        &self,
3452        config: &str,
3453        dtype: DType,
3454        weight_pack_factor: usize,
3455    ) -> Result<Vec<usize>> {
3456        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3457
3458        let mut layer_sizes_in_bytes = Vec::new();
3459        for layer_idx in 0..cfg.num_hidden_layers {
3460            let input_layernorm = cfg.hidden_size;
3461            let post_attention_layernorm = cfg.hidden_size;
3462
3463            let size_in = cfg.hidden_size;
3464            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3465            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3466            let q_proj = size_in * size_q / weight_pack_factor;
3467            let k_proj = size_in * size_kv / weight_pack_factor;
3468            let v_proj = size_in * size_kv / weight_pack_factor;
3469            let o_proj = size_q * size_in / weight_pack_factor;
3470
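            // A layer is sparse (routed experts plus a gate) unless listed in mlp_only_layers,
            // recurring every decoder_sparse_step layers; other layers use a dense MLP.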
3471            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3472                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3473            {
3474                let gate_size = cfg.hidden_size * cfg.num_experts;
3475                let expert_size = {
3476                    let h_size = cfg.hidden_size;
3477                    let i_size = cfg.moe_intermediate_size;
3478                    let gate_proj = h_size * i_size / weight_pack_factor;
3479                    let up_proj = h_size * i_size / weight_pack_factor;
3480                    let down_proj = i_size * h_size / weight_pack_factor;
3481                    gate_proj + up_proj + down_proj
3482                };
3483                expert_size * cfg.num_experts + gate_size
3484            } else {
3485                let h_size = cfg.hidden_size;
3486                let i_size = cfg.intermediate_size;
3487                let gate_proj = h_size * i_size / weight_pack_factor;
3488                let up_proj = h_size * i_size / weight_pack_factor;
3489                let down_proj = i_size * h_size / weight_pack_factor;
3490                gate_proj + up_proj + down_proj
3491            };
3492
3493            let q_norm = cfg.head_dim();
3494            let k_norm = cfg.head_dim();
3495
3496            let size_elems = input_layernorm
3497                + post_attention_layernorm
3498                + q_proj
3499                + k_proj
3500                + v_proj
3501                + o_proj
3502                + mlp_size
3503                + q_norm
3504                + k_norm;
3505
3506            let size_in_bytes = size_elems * dtype.size_in_bytes();
3507            layer_sizes_in_bytes.push(size_in_bytes);
3508        }
3509
3510        Ok(layer_sizes_in_bytes)
3511    }
3512
3513    fn num_layers(&self, config: &str) -> Result<usize> {
3514        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3515        Ok(cfg.num_hidden_layers)
3516    }
3517
3518    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3519        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3520
3521        let cfg = ModelConfigMetadata {
3522            max_seq_len: cfg.max_position_embeddings,
3523            num_layers: cfg.num_hidden_layers,
3524            hidden_size: cfg.hidden_size,
3525            num_kv_heads: cfg.num_key_value_heads,
3526            num_attn_heads: cfg.num_attention_heads,
3527            sliding_window: cfg.sliding_window,
3528            k_head_dim: cfg.head_dim(),
3529            v_head_dim: cfg.head_dim(),
3530        };
3531
3532        Ok(Box::new(cfg))
3533    }
3534}
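
// A hedged sketch of combining the device-mapping size hooks to estimate a model's
// total weight footprint (assumes `config_json` and a target `dtype`; a rough
// illustration, not the crate's actual device-mapping algorithm):
//
//     let loader = Qwen3MoELoader;
//     let per_layer = loader.layer_sizes_in_bytes(&config_json, dtype, 1)?;
//     let total_bytes = loader.non_mapped_size_in_bytes(&config_json, dtype, 1)?
//         + per_layer.iter().sum::<usize>();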
3535
3536// ======================== SmolLm3 loader
3537
3538/// [`NormalLoader`] for a SmolLm3 model.
3539///
3540/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3541pub struct SmolLm3Loader;
3542
3543impl NormalModelLoader for SmolLm3Loader {
3544    fn load(
3545        &self,
3546        config: &str,
3547        vb: ShardedVarBuilder,
3548        normal_loading_metadata: NormalLoadingMetadata,
3549        attention_mechanism: AttentionImplementation,
3550    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3551        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3552
3553        Ok(Box::new(models::smollm3::SmolLm3::new(
3554            &cfg,
3555            vb,
3556            self.is_gptx(config)?,
3557            normal_loading_metadata,
3558            attention_mechanism,
3559        )?))
3560    }
3561    fn load_xlora(
3562        &self,
3563        _config: &str,
3564        _vb: ShardedVarBuilder,
3565        _lora_config: &[((String, String), LoraConfig)],
3566        _xlora_config: Option<XLoraConfig>,
3567        _xlora_ordering: Ordering,
3568        _normal_loading_metadata: NormalLoadingMetadata,
3569        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3570    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3571        todo!()
3572    }
3573    fn is_gptx(&self, _: &str) -> Result<bool> {
3574        Ok(true)
3575    }
3576    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3577        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3578        Ok(Box::new(cfg))
3579    }
3580}
3581
3582impl IsqModelLoader for SmolLm3Loader {
3583    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3584        Ok(vec![
3585            Regex::new(r"lm_head\.(weight|bias)$")?,
3586            // Attention
3587            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3588            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3589            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3590            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3591            // MLP
3592            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3593            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3594            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3595        ])
3596    }
3597    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3598        self.isq_layer_regexes(config)
3599    }
3600}
3601
3602impl DeviceMappedModelLoader for SmolLm3Loader {
3603    fn mapped_max_act_size_elems(
3604        &self,
3605        config: &str,
3606        params: &AutoDeviceMapParams,
3607        prompt_chunksize: usize,
3608    ) -> Result<usize> {
3609        let AutoDeviceMapParams::Text {
3610            max_seq_len: _,
3611            max_batch_size,
3612        } = params
3613        else {
3614            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3615        };
3616
3617        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3618
3619        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3620    }
3621    fn non_mapped_max_act_size_elems(
3622        &self,
3623        _config: &str,
3624        _params: &AutoDeviceMapParams,
3625    ) -> Result<usize> {
3626        Ok(0)
3627    }
3628
3629    fn non_mapped_size_in_bytes(
3630        &self,
3631        config: &str,
3632        dtype: DType,
3633        weight_pack_factor: usize,
3634    ) -> Result<usize> {
3635        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3636
3637        let elems = {
3638            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3639            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3640            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3641                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3642            } else {
3643                0
3644            };
3645            let norm = cfg.hidden_size;
3646            embed_tokens + lm_head + norm
3647        };
3648        Ok(elems * dtype.size_in_bytes())
3649    }
3650
3651    fn layer_sizes_in_bytes(
3652        &self,
3653        config: &str,
3654        dtype: DType,
3655        weight_pack_factor: usize,
3656    ) -> Result<Vec<usize>> {
3657        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3658
3659        let per_layer_elems = {
3660            let input_layernorm = cfg.hidden_size;
3661            let post_attention_layernorm = cfg.hidden_size;
3662
3663            let size_in = cfg.hidden_size;
3664            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3665            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3666            let q_proj = size_in * size_q / weight_pack_factor;
3667            let k_proj = size_in * size_kv / weight_pack_factor;
3668            let v_proj = size_in * size_kv / weight_pack_factor;
3669            let o_proj = size_q * size_in / weight_pack_factor;
3670
3671            let h_size = cfg.hidden_size;
3672            let i_size = cfg.intermediate_size;
3673            let gate_proj = h_size * i_size / weight_pack_factor;
3674            let up_proj = h_size * i_size / weight_pack_factor;
3675            let down_proj = i_size * h_size / weight_pack_factor;
3676
3677            input_layernorm
3678                + post_attention_layernorm
3679                + q_proj
3680                + k_proj
3681                + v_proj
3682                + o_proj
3683                + gate_proj
3684                + up_proj
3685                + down_proj
3686        };
3687        Ok(vec![
3688            per_layer_elems * dtype.size_in_bytes();
3689            cfg.num_hidden_layers
3690        ])
3691    }
3692
3693    fn num_layers(&self, config: &str) -> Result<usize> {
3694        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3695
3696        Ok(cfg.num_hidden_layers)
3697    }
3698    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3699        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3700
3701        let cfg = ModelConfigMetadata {
3702            max_seq_len: cfg.max_position_embeddings,
3703            num_layers: cfg.num_hidden_layers,
3704            hidden_size: cfg.hidden_size,
3705            num_kv_heads: cfg.num_key_value_heads,
3706            num_attn_heads: cfg.num_attention_heads,
3707            sliding_window: None,
3708            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3709            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3710        };
3711
3712        Ok(Box::new(cfg))
3713    }
3714}