mistralrs_core/pipeline/loaders/normal_loaders.rs

use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    lora::{LoraConfig, Ordering},
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{
        isq::IsqModelLoader,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::varbuilder_utils::DeviceForLoadTensor,
    xlora_models::NonGranularState,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use mistralrs_quant::log::once_log_info;

use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use crate::{
    models,
    xlora_models::{self, XLoraConfig},
};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

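/// Core interface implemented by every plain ("normal") text decoder model.
///
/// It unifies the standard and X-LoRA forward passes and exposes the cache,
/// device, and configuration metadata that the pipeline needs at runtime.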
pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn is_xlora(&self) -> bool;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
}

/// Metadata for loading a model with ISQ or device mapping.
pub struct NormalLoadingMetadata {
    // Device mapping metadata which can be used to construct a concrete device mapper
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Flag to check if loading in ISQ
    pub loading_isq: bool,
    // Device mapping target device (the one that is not the CPU)
    pub real_device: Device,
    // MultiProgress support for parallelized loading
    pub multi_progress: Arc<MultiProgress>,
    // Optional Matryoshka Transformer slicing configuration
    pub matformer_slicing_config: Option<MatformerSliceConfig>,
}

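/// Builder for a [`NormalModel`]: deserializes a JSON model config and
/// constructs the concrete model (optionally with X-LoRA adapters), along with
/// loader metadata such as the target device for each tensor.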
pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    #[allow(clippy::too_many_arguments)]
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
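            // Map tensor names such as `model.layers.12.self_attn.q_proj.weight`
            // to their layer index so each weight can be loaded straight onto the
            // device that layer is mapped to; everything else goes to the base device.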
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
/// The architecture to load the normal model as.
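///
/// Values round-trip through their serialized names; a minimal sketch
/// (assuming the type is in scope):
///
/// ```ignore
/// use std::str::FromStr;
///
/// let arch = NormalLoaderType::from_str("mistral").unwrap();
/// assert_eq!(arch.to_string(), "mistral");
/// ```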
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    #[serde(rename = "qwen3")]
    Qwen3,
    #[serde(rename = "glm4")]
    GLM4,
    #[serde(rename = "glm4moelite")]
    GLM4MoeLite,
    #[serde(rename = "glm4moe")]
    GLM4Moe,
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
    #[serde(rename = "smollm3")]
    SmolLm3,
    #[serde(rename = "granitemoehybrid")]
    GraniteMoeHybrid,
    #[serde(rename = "gpt_oss")]
    GptOss,
}

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
impl NormalLoaderType {
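    /// Map a Hugging Face `architectures` entry from `config.json`
    /// (e.g. `"MistralForCausalLM"`) to the corresponding loader type.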
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "MistralForCausalLM" => Ok(Self::Mistral),
            "MixtralForCausalLM" => Ok(Self::Mixtral),
            "GemmaForCausalLM" => Ok(Self::Gemma),
            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
            "PhiForCausalLM" => Ok(Self::Phi2),
            "Phi3ForCausalLM" => Ok(Self::Phi3),
            "LlamaForCausalLM" => Ok(Self::Llama),
            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
            "Glm4ForCausalLM" => Ok(Self::GLM4),
            "Glm4MoeLiteForCausalLM" => Ok(Self::GLM4MoeLite),
            "Glm4MoeForCausalLM" => Ok(Self::GLM4Moe),
            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
            "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
            "GptOssForCausalLM" => Ok(Self::GptOss),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `*ForCausalLM` model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for NormalLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "mistral" => Ok(Self::Mistral),
            "gemma" => Ok(Self::Gemma),
            "mixtral" => Ok(Self::Mixtral),
            "llama" => Ok(Self::Llama),
            "phi2" => Ok(Self::Phi2),
            "phi3" => Ok(Self::Phi3),
            "qwen2" => Ok(Self::Qwen2),
            "gemma2" => Ok(Self::Gemma2),
            "starcoder2" => Ok(Self::Starcoder2),
            "phi3.5moe" => Ok(Self::Phi3_5MoE),
            "deepseekv2" => Ok(Self::DeepSeekV2),
            "deepseekv3" => Ok(Self::DeepSeekV3),
            "qwen3" => Ok(Self::Qwen3),
            "glm4" => Ok(Self::GLM4),
            "glm4moelite" => Ok(Self::GLM4MoeLite),
            "glm4moe" => Ok(Self::GLM4Moe),
            "qwen3moe" => Ok(Self::Qwen3Moe),
            "smollm3" => Ok(Self::SmolLm3),
            "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
            "gpt_oss" => Ok(Self::GptOss),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `glm4moelite`, `glm4moe`, `qwen3moe`, `smollm3`, `granitemoehybrid`, `gpt_oss`.")),
        }
    }
}

impl Display for NormalLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gemma => write!(f, "gemma"),
            Self::Gemma2 => write!(f, "gemma2"),
            Self::Llama => write!(f, "llama"),
            Self::Mistral => write!(f, "mistral"),
            Self::Mixtral => write!(f, "mixtral"),
            Self::Phi2 => write!(f, "phi2"),
            Self::Phi3 => write!(f, "phi3"),
            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
            Self::Qwen2 => write!(f, "qwen2"),
            Self::Starcoder2 => write!(f, "starcoder2"),
            Self::DeepSeekV2 => write!(f, "deepseekv2"),
            Self::DeepSeekV3 => write!(f, "deepseekv3"),
            Self::Qwen3 => write!(f, "qwen3"),
            Self::GLM4 => write!(f, "glm4"),
            Self::GLM4MoeLite => write!(f, "glm4moelite"),
            Self::GLM4Moe => write!(f, "glm4moe"),
            Self::Qwen3Moe => write!(f, "qwen3moe"),
            Self::SmolLm3 => write!(f, "smollm3"),
            Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
            Self::GptOss => write!(f, "gpt_oss"),
        }
    }
}

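/// Count `$size` extra elements for an optional bias term when `$cond` is true.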
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

/// Load a model based on the Hugging Face Transformers `*ForCausalLM` model class.
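///
/// Dispatch is driven purely by the `architectures` field; a minimal sketch
/// (the config shown is illustrative, not a complete model config):
///
/// ```ignore
/// let config = r#"{ "architectures": ["MistralForCausalLM"] }"#;
/// // Delegates to `MistralLoader` under the hood.
/// let paged = AutoNormalLoader.supports_paged_attention(config)?;
/// assert!(paged);
/// ```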
pub struct AutoNormalLoader;

#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    architectures: Vec<String>,
}

impl AutoNormalLoader {
    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected to have one name for `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = NormalLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
            NormalLoaderType::GLM4MoeLite => Ok(Box::new(GLM4MoeLiteLoader)),
            NormalLoaderType::GLM4Moe => Ok(Box::new(GLM4MoeLoader)),
            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
            NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
            NormalLoaderType::GptOss => Ok(Box::new(GptOssLoader)),
        }
    }
}

impl NormalModelLoader for AutoNormalLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load_xlora(
            config,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            normal_loading_metadata,
            preload_adapters,
        )
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }
    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.supports_paged_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoNormalLoader {
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
    }
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for AutoNormalLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

// ======================== Mistral loader

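/// [`NormalLoader`] for a Mistral model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html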
pub struct MistralLoader;

impl NormalModelLoader for MistralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(models::mistral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(xlora_models::XLoraMistral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MistralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MistralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

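        // Peak mapped activation: one (seq x seq) attention-score map per head per
        // batch element, with the sequence length capped at ATTENTION_CHUNK_SIZE.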
        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

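        // Count the parameter elements of one decoder layer; dividing by
        // `weight_pack_factor` approximates packed/quantized weight storage.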
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Gemma loader

/// [`NormalLoader`] for a Gemma model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct GemmaLoader;

impl NormalModelLoader for GemmaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::gemma::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraGemma::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for GemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Llama loader

/// [`NormalLoader`] for a Llama model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct LlamaLoader;

impl NormalModelLoader for LlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::llama::Llama::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraLlama::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for LlamaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for LlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Mixtral loader

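/// [`NormalLoader`] for a Mixtral model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html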
pub struct MixtralLoader;

impl NormalModelLoader for MixtralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::mixtral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraMixtral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MixtralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // Experts
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MixtralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let moe_block = {
                let gate = cfg.hidden_size * cfg.num_local_experts;
                // Assume quantizing weight pack factor
                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                gate + cfg.num_local_experts * w1
                    + cfg.num_local_experts * w2
                    + cfg.num_local_experts * w3
            };

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + moe_block
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Phi2 loader

/// [`NormalLoader`] for a Phi 2 model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Phi2Loader;

impl NormalModelLoader for Phi2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::phi2::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraPhi2::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Phi2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor + size_in;
            let (q_norm, k_norm) = if cfg.qk_layernorm {
                (cfg.head_dim(), cfg.head_dim())
            } else {
                (0, 0)
            };

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let fc1 = h_size * i_size / weight_pack_factor;
            let fc2 = h_size * i_size / weight_pack_factor;

            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

1416// ======================== Phi3 loader
1417
1418/// [`NormalLoader`] for a Phi 3 model.
1419///
1420/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1421pub struct Phi3Loader;
1422
1423impl NormalModelLoader for Phi3Loader {
1424    fn load(
1425        &self,
1426        config: &str,
1427        vb: ShardedVarBuilder,
1428        normal_loading_metadata: NormalLoadingMetadata,
1429        attention_mechanism: AttentionImplementation,
1430    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1431        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1432
1433        Ok(Box::new(models::phi3::Model::new(
1434            &cfg,
1435            vb,
1436            self.is_gptx(config)?,
1437            normal_loading_metadata,
1438            attention_mechanism,
1439        )?))
1440    }
1441    fn load_xlora(
1442        &self,
1443        config: &str,
1444        vb: ShardedVarBuilder,
1445        lora_config: &[((String, String), LoraConfig)],
1446        xlora_config: Option<XLoraConfig>,
1447        xlora_ordering: Ordering,
1448        normal_loading_metadata: NormalLoadingMetadata,
1449        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1450    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1451        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1452
1453        Ok(Box::new(xlora_models::XLoraPhi3::new(
1454            &cfg,
1455            vb,
1456            lora_config,
1457            xlora_config,
1458            xlora_ordering,
1459            self.is_gptx(config)?,
1460            normal_loading_metadata,
1461            preload_adapters,
1462        )?))
1463    }
1464    fn is_gptx(&self, _: &str) -> Result<bool> {
1465        Ok(true)
1466    }
1467    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1468        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1469
1470        Ok(Box::new(cfg))
1471    }
1472}
1473
1474impl IsqModelLoader for Phi3Loader {
1475    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1476        Ok(vec![
1477            Regex::new(r"lm_head\.(weight|bias)$")?,
1478            // Attention
1479            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1480            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1481            // MLP
1482            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1483            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1484            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1485        ])
1486    }
1487    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1488        self.isq_layer_regexes(config)
1489    }
1490}
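
// Editorial sketch: which tensor names the Phi 3 ISQ regexes above select. The tensor names
// are hypothetical examples in the usual HF `model.layers.N....` layout; the actual names
// come from the checkpoint being loaded.
#[cfg(test)]
mod phi3_isq_regex_sketch {
    use regex::Regex;

    #[test]
    fn qkv_proj_weights_are_selected_but_norms_are_not() {
        let qkv = Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$").unwrap();
        assert!(qkv.is_match("model.layers.0.self_attn.qkv_proj.weight"));
        assert!(qkv.is_match("model.layers.31.self_attn.qkv_proj.bias"));
        // Layer norms are intentionally left out of the ISQ patterns.
        assert!(!qkv.is_match("model.layers.0.input_layernorm.weight"));
    }
}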
1491
1492impl DeviceMappedModelLoader for Phi3Loader {
1493    fn mapped_max_act_size_elems(
1494        &self,
1495        config: &str,
1496        params: &AutoDeviceMapParams,
1497    ) -> Result<usize> {
1498        let AutoDeviceMapParams::Text {
1499            max_seq_len,
1500            max_batch_size,
1501        } = params
1502        else {
1503            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1504        };
1505
1506        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1507
1508        Ok(
1509            max_batch_size
1510                * cfg.num_attention_heads
1511                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1512        )
1513    }
1514    fn non_mapped_max_act_size_elems(
1515        &self,
1516        _config: &str,
1517        _params: &AutoDeviceMapParams,
1518    ) -> Result<usize> {
1519        Ok(0)
1520    }
1521
1522    fn non_mapped_size_in_bytes(
1523        &self,
1524        config: &str,
1525        dtype: DType,
1526        weight_pack_factor: usize,
1527        _matformer_config: Option<&MatformerSliceConfig>,
1528    ) -> Result<usize> {
1529        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1530
1531        let elems = {
1532            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1533            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1534            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1535                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1536            } else {
1537                0
1538            };
1539            let norm = cfg.hidden_size;
1540            embed_tokens + lm_head + norm
1541        };
1542        Ok(elems * dtype.size_in_bytes())
1543    }
1544
1545    fn layer_sizes_in_bytes(
1546        &self,
1547        config: &str,
1548        dtype: DType,
1549        weight_pack_factor: usize,
1550        _matformer_config: Option<&MatformerSliceConfig>,
1551    ) -> Result<Vec<usize>> {
1552        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1553
1554        let per_layer_elems = {
1555            let input_layernorm = cfg.hidden_size;
1556            let post_attention_layernorm = cfg.hidden_size;
1557
1558            let size_in = cfg.hidden_size;
1559            let head_dim = cfg.head_dim();
1560            let op_size =
1561                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1562            let qkv_proj = size_in * op_size / weight_pack_factor;
1563            let o_proj =
1564                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1565
1566            let h_size = cfg.hidden_size;
1567            let i_size = cfg.intermediate_size;
1568            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1569            let down_proj = h_size * i_size / weight_pack_factor;
1570
1571            input_layernorm
1572                + post_attention_layernorm
1573                + qkv_proj
1574                + o_proj
1575                + gate_up_proj
1576                + down_proj
1577        };
1578        Ok(vec![
1579            per_layer_elems * dtype.size_in_bytes();
1580            cfg.num_hidden_layers
1581        ])
1582    }
1583
1584    fn num_layers(&self, config: &str) -> Result<usize> {
1585        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1586
1587        Ok(cfg.num_hidden_layers)
1588    }
1589
1590    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1591        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1592
1593        let cfg = ModelConfigMetadata {
1594            max_seq_len: cfg.max_position_embeddings,
1595            num_layers: cfg.num_hidden_layers,
1596            hidden_size: cfg.hidden_size,
1597            num_kv_heads: cfg.num_key_value_heads,
1598            num_attn_heads: cfg.num_attention_heads,
1599            sliding_window: cfg.sliding_window,
1600            k_head_dim: cfg.head_dim(),
1601            v_head_dim: cfg.head_dim(),
1602        };
1603
1604        Ok(Box::new(cfg))
1605    }
1606}
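
// Editorial sketch: a worked instance of the per-layer element count computed by
// `layer_sizes_in_bytes` above, using illustrative config values (roughly the size of a
// small Phi 3 variant) and no weight packing. Not derived from a real config file.
#[cfg(test)]
mod phi3_layer_size_sketch {
    #[test]
    fn per_layer_elems_worked_example() {
        let (hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size) =
            (3072usize, 32usize, 32usize, 96usize, 8192usize);
        let weight_pack_factor = 1usize;
        let dtype_bytes = 2usize; // e.g. f16/bf16

        let input_layernorm = hidden_size;
        let post_attention_layernorm = hidden_size;
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
        let qkv_proj = hidden_size * op_size / weight_pack_factor;
        let o_proj = (num_heads * head_dim) * hidden_size / weight_pack_factor + hidden_size;
        let gate_up_proj = hidden_size * (2 * intermediate_size) / weight_pack_factor;
        let down_proj = hidden_size * intermediate_size / weight_pack_factor;

        let per_layer_elems = input_layernorm
            + post_attention_layernorm
            + qkv_proj
            + o_proj
            + gate_up_proj
            + down_proj;
        // ~113M elements per layer, ~216 MiB at 2 bytes per element.
        assert_eq!(per_layer_elems, 113_255_424);
        assert_eq!(per_layer_elems * dtype_bytes, 226_510_848);
    }
}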
1607
1608// ======================== Qwen2 loader
1609
1610/// [`NormalLoader`] for a Qwen 2 model.
1611///
1612/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1613pub struct Qwen2Loader;
1614
1615impl NormalModelLoader for Qwen2Loader {
1616    fn load(
1617        &self,
1618        config: &str,
1619        vb: ShardedVarBuilder,
1620        normal_loading_metadata: NormalLoadingMetadata,
1621        attention_mechanism: AttentionImplementation,
1622    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1623        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1624
1625        Ok(Box::new(models::qwen2::Model::new(
1626            &cfg,
1627            vb,
1628            self.is_gptx(config)?,
1629            normal_loading_metadata,
1630            attention_mechanism,
1631        )?))
1632    }
1633    fn load_xlora(
1634        &self,
1635        _config: &str,
1636        _vb: ShardedVarBuilder,
1637        _lora_config: &[((String, String), LoraConfig)],
1638        _xlora_config: Option<XLoraConfig>,
1639        _xlora_ordering: Ordering,
1640        _normal_loading_metadata: NormalLoadingMetadata,
1641        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1642    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1643        todo!()
1644    }
1645    fn is_gptx(&self, _: &str) -> Result<bool> {
1646        Ok(true)
1647    }
1648    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1649        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1650
1651        Ok(Box::new(cfg))
1652    }
1653}
1654
1655impl IsqModelLoader for Qwen2Loader {
1656    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1657        Ok(vec![
1658            Regex::new(r"lm_head\.(weight|bias)$")?,
1659            // Attention
1660            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1661            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1662            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1663            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1664            // MLP
1665            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1666            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1667            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1668        ])
1669    }
1670    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1671        self.isq_layer_regexes(config)
1672    }
1673}
1674
1675impl DeviceMappedModelLoader for Qwen2Loader {
1676    fn mapped_max_act_size_elems(
1677        &self,
1678        config: &str,
1679        params: &AutoDeviceMapParams,
1680    ) -> Result<usize> {
1681        let AutoDeviceMapParams::Text {
1682            max_seq_len,
1683            max_batch_size,
1684        } = params
1685        else {
1686            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1687        };
1688
1689        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1690
1691        Ok(
1692            max_batch_size
1693                * cfg.num_attention_heads
1694                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1695        )
1696    }
1697    fn non_mapped_max_act_size_elems(
1698        &self,
1699        _config: &str,
1700        _params: &AutoDeviceMapParams,
1701    ) -> Result<usize> {
1702        Ok(0)
1703    }
1704
1705    fn non_mapped_size_in_bytes(
1706        &self,
1707        config: &str,
1708        dtype: DType,
1709        weight_pack_factor: usize,
1710        _matformer_config: Option<&MatformerSliceConfig>,
1711    ) -> Result<usize> {
1712        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1713
1714        let elems = {
1715            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1716            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1717            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1718                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1719            } else {
1720                0
1721            };
1722            let norm = cfg.hidden_size;
1723            embed_tokens + lm_head + norm
1724        };
1725        Ok(elems * dtype.size_in_bytes())
1726    }
1727
1728    fn layer_sizes_in_bytes(
1729        &self,
1730        config: &str,
1731        dtype: DType,
1732        weight_pack_factor: usize,
1733        _matformer_config: Option<&MatformerSliceConfig>,
1734    ) -> Result<Vec<usize>> {
1735        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1736
1737        let per_layer_elems = {
1738            let input_layernorm = cfg.hidden_size;
1739            let post_attention_layernorm = cfg.hidden_size;
1740
1741            let size_in = cfg.hidden_size;
1742            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1743            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1744            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1745            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1746            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1747            let o_proj = size_q * size_in / weight_pack_factor;
1748
1749            let h_size = cfg.hidden_size;
1750            let i_size = cfg.intermediate_size;
1751            let gate_proj = h_size * i_size / weight_pack_factor;
1752            let up_proj = h_size * i_size / weight_pack_factor;
1753            let down_proj = i_size * h_size / weight_pack_factor;
1754
1755            input_layernorm
1756                + post_attention_layernorm
1757                + q_proj
1758                + k_proj
1759                + v_proj
1760                + o_proj
1761                + gate_proj
1762                + up_proj
1763                + down_proj
1764        };
1765        Ok(vec![
1766            per_layer_elems * dtype.size_in_bytes();
1767            cfg.num_hidden_layers
1768        ])
1769    }
1770
1771    fn num_layers(&self, config: &str) -> Result<usize> {
1772        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1773
1774        Ok(cfg.num_hidden_layers)
1775    }
1776
1777    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1778        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1779
1780        let cfg = ModelConfigMetadata {
1781            max_seq_len: cfg.max_position_embeddings,
1782            num_layers: cfg.num_hidden_layers,
1783            hidden_size: cfg.hidden_size,
1784            num_kv_heads: cfg.num_key_value_heads,
1785            num_attn_heads: cfg.num_attention_heads,
1786            sliding_window: cfg.sliding_window,
1787            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1788            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1789        };
1790
1791        Ok(Box::new(cfg))
1792    }
1793}
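
// Editorial sketch: the tied-embedding rule used by `non_mapped_size_in_bytes` above. A
// separate `lm_head` is only skipped when `tie_word_embeddings` is set *and* the weights
// are not packed (pack factor 1); a packed/quantized head cannot simply reuse the embedding
// matrix. The concrete sizes below are hypothetical.
#[cfg(test)]
mod qwen2_tied_lm_head_sketch {
    fn lm_head_elems(
        hidden_size: usize,
        vocab_size: usize,
        tie_word_embeddings: bool,
        weight_pack_factor: usize,
    ) -> usize {
        if !tie_word_embeddings || weight_pack_factor != 1 {
            hidden_size * vocab_size / weight_pack_factor
        } else {
            0
        }
    }

    #[test]
    fn tied_and_unpacked_skips_the_head() {
        assert_eq!(lm_head_elems(1024, 32_000, true, 1), 0);
        // Packed weights still account for a head, at 1/8 of the elements here.
        assert_eq!(lm_head_elems(1024, 32_000, true, 8), 1024 * 32_000 / 8);
        assert_eq!(lm_head_elems(1024, 32_000, false, 1), 1024 * 32_000);
    }
}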
1794
1795// ======================== Gemma2 loader
1796
1797/// [`NormalLoader`] for a Gemma2 model.
1798///
1799/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1800pub struct Gemma2Loader;
1801
1802impl NormalModelLoader for Gemma2Loader {
1803    fn load(
1804        &self,
1805        config: &str,
1806        vb: ShardedVarBuilder,
1807        normal_loading_metadata: NormalLoadingMetadata,
1808        attention_mechanism: AttentionImplementation,
1809    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1810        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1811
1812        Ok(Box::new(models::gemma2::Model::new(
1813            &cfg,
1814            vb,
1815            self.is_gptx(config)?,
1816            normal_loading_metadata,
1817            attention_mechanism,
1818        )?))
1819    }
1820    fn load_xlora(
1821        &self,
1822        config: &str,
1823        vb: ShardedVarBuilder,
1824        lora_config: &[((String, String), LoraConfig)],
1825        xlora_config: Option<XLoraConfig>,
1826        xlora_ordering: Ordering,
1827        normal_loading_metadata: NormalLoadingMetadata,
1828        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1829    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1830        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1831
1832        Ok(Box::new(xlora_models::XLoraGemma2::new(
1833            &cfg,
1834            vb,
1835            lora_config,
1836            xlora_config,
1837            xlora_ordering,
1838            self.is_gptx(config)?,
1839            normal_loading_metadata,
1840            preload_adapters,
1841        )?))
1842    }
1843    fn is_gptx(&self, _: &str) -> Result<bool> {
1844        Ok(true)
1845    }
1846    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1847        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1848
1849        Ok(Box::new(cfg))
1850    }
1851}
1852
1853impl IsqModelLoader for Gemma2Loader {
1854    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1855        Ok(vec![
1856            Regex::new(r"lm_head\.(weight|bias)$")?,
1857            // Attention
1858            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1859            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1860            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1861            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1862            // MLP
1863            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1864            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1865            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1866        ])
1867    }
1868    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1869        self.isq_layer_regexes(config)
1870    }
1871}
1872
1873impl DeviceMappedModelLoader for Gemma2Loader {
1874    fn mapped_max_act_size_elems(
1875        &self,
1876        config: &str,
1877        params: &AutoDeviceMapParams,
1878    ) -> Result<usize> {
1879        let AutoDeviceMapParams::Text {
1880            max_seq_len,
1881            max_batch_size,
1882        } = params
1883        else {
1884            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1885        };
1886
1887        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1888
1889        Ok(
1890            max_batch_size
1891                * cfg.num_attention_heads
1892                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1893        )
1894    }
1895    fn non_mapped_max_act_size_elems(
1896        &self,
1897        _config: &str,
1898        _params: &AutoDeviceMapParams,
1899    ) -> Result<usize> {
1900        Ok(0)
1901    }
1902
1903    fn non_mapped_size_in_bytes(
1904        &self,
1905        config: &str,
1906        dtype: DType,
1907        weight_pack_factor: usize,
1908        _matformer_config: Option<&MatformerSliceConfig>,
1909    ) -> Result<usize> {
1910        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1911
1912        let elems = {
1913            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1914            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1915            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1916                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1917            } else {
1918                0
1919            };
1920            let norm = cfg.hidden_size;
1921            embed_tokens + lm_head + norm
1922        };
1923        Ok(elems * dtype.size_in_bytes())
1924    }
1925
1926    fn layer_sizes_in_bytes(
1927        &self,
1928        config: &str,
1929        dtype: DType,
1930        weight_pack_factor: usize,
1931        _matformer_config: Option<&MatformerSliceConfig>,
1932    ) -> Result<Vec<usize>> {
1933        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1934
1935        let per_layer_elems = {
1936            let input_layernorm = cfg.hidden_size;
1937            let post_attention_layernorm = cfg.hidden_size;
1938
1939            let size_in = cfg.hidden_size;
1940            let size_q = cfg.head_dim * cfg.num_attention_heads;
1941            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1942            let q_proj =
1943                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1944            let k_proj =
1945                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1946            let v_proj =
1947                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1948            let o_proj =
1949                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1950
1951            let h_size = cfg.hidden_size;
1952            let i_size = cfg.intermediate_size;
1953            let gate_proj = h_size * i_size / weight_pack_factor;
1954            let up_proj = h_size * i_size / weight_pack_factor;
1955            let down_proj = i_size * h_size / weight_pack_factor;
1956
1957            input_layernorm
1958                + post_attention_layernorm
1959                + q_proj
1960                + k_proj
1961                + v_proj
1962                + o_proj
1963                + gate_proj
1964                + up_proj
1965                + down_proj
1966        };
1967        Ok(vec![
1968            per_layer_elems * dtype.size_in_bytes();
1969            cfg.num_hidden_layers
1970        ])
1971    }
1972
1973    fn num_layers(&self, config: &str) -> Result<usize> {
1974        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1975
1976        Ok(cfg.num_hidden_layers)
1977    }
1978    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1979        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1980
1981        let cfg = ModelConfigMetadata {
1982            max_seq_len: cfg.max_position_embeddings,
1983            num_layers: cfg.num_hidden_layers,
1984            hidden_size: cfg.hidden_size,
1985            num_kv_heads: cfg.num_key_value_heads,
1986            num_attn_heads: cfg.num_attention_heads,
1987            sliding_window: None, // None to be more forgiving; some Gemma2 configs do not set it
1988            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1989            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1990        };
1991
1992        Ok(Box::new(cfg))
1993    }
1994}
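
// Editorial sketch: the `bias_if!` macro used in the size formulas above is defined earlier
// in this file; its intent is "count these bias elements only when the config flag is set".
// The free function below is an assumed equivalent for illustration only, not the macro's
// actual definition.
#[cfg(test)]
mod bias_if_sketch {
    fn bias_if(flag: bool, elems: usize) -> usize {
        if flag {
            elems
        } else {
            0
        }
    }

    #[test]
    fn bias_elems_only_count_when_enabled() {
        let size_q = 4096usize;
        assert_eq!(bias_if(true, size_q), 4096);
        assert_eq!(bias_if(false, size_q), 0);
    }
}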
1995
1996// ======================== Starcoder2 loader
1997
1998/// [`NormalLoader`] for a Starcoder2 model.
1999///
2000/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2001pub struct Starcoder2Loader;
2002
2003impl NormalModelLoader for Starcoder2Loader {
2004    fn load(
2005        &self,
2006        config: &str,
2007        vb: ShardedVarBuilder,
2008        normal_loading_metadata: NormalLoadingMetadata,
2009        attention_mechanism: AttentionImplementation,
2010    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2011        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2012
2013        Ok(Box::new(models::starcoder2::Model::new(
2014            &cfg,
2015            vb,
2016            self.is_gptx(config)?,
2017            normal_loading_metadata,
2018            attention_mechanism,
2019        )?))
2020    }
2021    fn load_xlora(
2022        &self,
2023        config: &str,
2024        vb: ShardedVarBuilder,
2025        lora_config: &[((String, String), LoraConfig)],
2026        xlora_config: Option<XLoraConfig>,
2027        xlora_ordering: Ordering,
2028        normal_loading_metadata: NormalLoadingMetadata,
2029        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2030    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2031        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2032
2033        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2034            &cfg,
2035            vb,
2036            lora_config,
2037            xlora_config,
2038            xlora_ordering,
2039            self.is_gptx(config)?,
2040            normal_loading_metadata,
2041            preload_adapters,
2042        )?))
2043    }
2044    fn is_gptx(&self, _: &str) -> Result<bool> {
2045        Ok(true)
2046    }
2047    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2048        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2049
2050        Ok(Box::new(cfg))
2051    }
2052}
2053
2054impl IsqModelLoader for Starcoder2Loader {
2055    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2056        Ok(vec![
2057            Regex::new(r"lm_head\.(weight|bias)$")?,
2058            // Attention
2059            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2060            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2061            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2062            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2063            // MLP
2064            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2065            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2066        ])
2067    }
2068    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2069        self.isq_layer_regexes(config)
2070    }
2071}
2072
2073impl DeviceMappedModelLoader for Starcoder2Loader {
2074    fn mapped_max_act_size_elems(
2075        &self,
2076        config: &str,
2077        params: &AutoDeviceMapParams,
2078    ) -> Result<usize> {
2079        let AutoDeviceMapParams::Text {
2080            max_seq_len,
2081            max_batch_size,
2082        } = params
2083        else {
2084            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2085        };
2086
2087        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2088
2089        Ok(
2090            max_batch_size
2091                * cfg.num_attention_heads
2092                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2093        )
2094    }
2095    fn non_mapped_max_act_size_elems(
2096        &self,
2097        _config: &str,
2098        _params: &AutoDeviceMapParams,
2099    ) -> Result<usize> {
2100        Ok(0)
2101    }
2102
2103    fn non_mapped_size_in_bytes(
2104        &self,
2105        config: &str,
2106        dtype: DType,
2107        weight_pack_factor: usize,
2108        _matformer_config: Option<&MatformerSliceConfig>,
2109    ) -> Result<usize> {
2110        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2111
2112        let elems = {
2113            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2114            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2115            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2116                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2117            } else {
2118                0
2119            };
2120            let norm = cfg.hidden_size + cfg.hidden_size;
2121            embed_tokens + lm_head + norm
2122        };
2123        Ok(elems * dtype.size_in_bytes())
2124    }
2125
2126    fn layer_sizes_in_bytes(
2127        &self,
2128        config: &str,
2129        dtype: DType,
2130        weight_pack_factor: usize,
2131        _matformer_config: Option<&MatformerSliceConfig>,
2132    ) -> Result<Vec<usize>> {
2133        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2134
2135        let per_layer_elems = {
2136            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2137            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2138
2139            let size_in = cfg.hidden_size;
2140            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2141            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2142            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2143            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2144            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2145            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2146
2147            let h_size = cfg.hidden_size;
2148            let i_size = cfg.intermediate_size;
2149            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2150            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2151
2152            input_layernorm
2153                + post_attention_layernorm
2154                + q_proj
2155                + k_proj
2156                + v_proj
2157                + o_proj
2158                + fc1
2159                + fc2
2160        };
2161        Ok(vec![
2162            per_layer_elems * dtype.size_in_bytes();
2163            cfg.num_hidden_layers
2164        ])
2165    }
2166
2167    fn num_layers(&self, config: &str) -> Result<usize> {
2168        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2169
2170        Ok(cfg.num_hidden_layers)
2171    }
2172
2173    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2174        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2175
2176        let cfg = ModelConfigMetadata {
2177            max_seq_len: cfg.max_position_embeddings,
2178            num_layers: cfg.num_hidden_layers,
2179            hidden_size: cfg.hidden_size,
2180            num_kv_heads: cfg.num_key_value_heads,
2181            num_attn_heads: cfg.num_attention_heads,
2182            sliding_window: cfg.sliding_window,
2183            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2184            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2185        };
2186
2187        Ok(Box::new(cfg))
2188    }
2189}
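
// Editorial sketch: the Starcoder2 norms are counted as `hidden_size + hidden_size` above
// because a LayerNorm carries both a weight and a bias vector, whereas the RMSNorm used by
// most other loaders in this file carries only a weight. The helper and numbers below are
// hypothetical, for illustration.
#[cfg(test)]
mod starcoder2_norm_size_sketch {
    fn norm_elems(hidden_size: usize, has_bias: bool) -> usize {
        if has_bias {
            2 * hidden_size // LayerNorm: weight + bias
        } else {
            hidden_size // RMSNorm: weight only
        }
    }

    #[test]
    fn layernorm_counts_twice_the_hidden_size() {
        assert_eq!(norm_elems(6144, true), 12_288);
        assert_eq!(norm_elems(6144, false), 6144);
    }
}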
2190
2191// ======================== Phi3.5 MoE loader
2192
2193/// [`NormalLoader`] for a Phi 3.5 MoE model.
2194///
2195/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2196pub struct Phi3_5MoELoader;
2197
2198impl NormalModelLoader for Phi3_5MoELoader {
2199    fn load(
2200        &self,
2201        config: &str,
2202        vb: ShardedVarBuilder,
2203        normal_loading_metadata: NormalLoadingMetadata,
2204        attention_mechanism: AttentionImplementation,
2205    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2206        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2207
2208        Ok(Box::new(models::phi3_5_moe::Model::new(
2209            &cfg,
2210            vb,
2211            self.is_gptx(config)?,
2212            normal_loading_metadata,
2213            attention_mechanism,
2214        )?))
2215    }
2216    fn load_xlora(
2217        &self,
2218        config: &str,
2219        vb: ShardedVarBuilder,
2220        lora_config: &[((String, String), LoraConfig)],
2221        xlora_config: Option<XLoraConfig>,
2222        xlora_ordering: Ordering,
2223        normal_loading_metadata: NormalLoadingMetadata,
2224        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2225    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2226        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2227
2228        Ok(Box::new(xlora_models::XLoraPhi3::new(
2229            &cfg,
2230            vb,
2231            lora_config,
2232            xlora_config,
2233            xlora_ordering,
2234            self.is_gptx(config)?,
2235            normal_loading_metadata,
2236            preload_adapters,
2237        )?))
2238    }
2239    fn is_gptx(&self, _: &str) -> Result<bool> {
2240        Ok(true)
2241    }
2242    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2243        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2244
2245        Ok(Box::new(cfg))
2246    }
2247}
2248
2249impl IsqModelLoader for Phi3_5MoELoader {
2250    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2251        Ok(vec![
2252            Regex::new(r"lm_head\.(weight|bias)$")?,
2253            // Attention
2254            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2255            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2256            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2257            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2258            // MLP
2259            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2260            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2261            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2262        ])
2263    }
2264    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2265        self.isq_layer_regexes(config)
2266    }
2267
2268    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2269        Ok(vec![
2270            Regex::new(r"lm_head\.(weight|bias)$")?,
2271            // MLP
2272            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2273            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2274            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2275        ])
2276    }
2277    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2278        self.isq_layer_regexes_moqe(config)
2279    }
2280}
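
// Editorial sketch: the MoQE ("MoE-only quantization") patterns above quantize only the
// expert MLPs and the lm_head, leaving attention projections alone. The tensor names below
// are hypothetical examples in the usual HF layout.
#[cfg(test)]
mod phi3_5_moe_moqe_regex_sketch {
    use regex::Regex;

    #[test]
    fn expert_weights_match_and_attention_is_excluded() {
        let w1 =
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")
                .unwrap();
        assert!(w1.is_match("model.layers.3.block_sparse_moe.experts.7.w1.weight"));
        // No attention pattern appears in the MoQE set, so q_proj stays unquantized.
        assert!(!w1.is_match("model.layers.3.self_attn.q_proj.weight"));
    }
}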
2281
2282impl DeviceMappedModelLoader for Phi3_5MoELoader {
2283    fn mapped_max_act_size_elems(
2284        &self,
2285        config: &str,
2286        params: &AutoDeviceMapParams,
2287    ) -> Result<usize> {
2288        let AutoDeviceMapParams::Text {
2289            max_seq_len,
2290            max_batch_size,
2291        } = params
2292        else {
2293            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2294        };
2295
2296        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2297
2298        Ok(
2299            max_batch_size
2300                * cfg.num_attention_heads
2301                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2302        )
2303    }
2304    fn non_mapped_max_act_size_elems(
2305        &self,
2306        _config: &str,
2307        _params: &AutoDeviceMapParams,
2308    ) -> Result<usize> {
2309        Ok(0)
2310    }
2311
2312    fn non_mapped_size_in_bytes(
2313        &self,
2314        config: &str,
2315        dtype: DType,
2316        weight_pack_factor: usize,
2317        _matformer_config: Option<&MatformerSliceConfig>,
2318    ) -> Result<usize> {
2319        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2320
2321        let elems = {
2322            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2323            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2324            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2325                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2326            } else {
2327                0
2328            };
2329            let norm = cfg.hidden_size;
2330            embed_tokens + lm_head + norm
2331        };
2332        Ok(elems * dtype.size_in_bytes())
2333    }
2334
2335    fn layer_sizes_in_bytes(
2336        &self,
2337        config: &str,
2338        dtype: DType,
2339        weight_pack_factor: usize,
2340        _matformer_config: Option<&MatformerSliceConfig>,
2341    ) -> Result<Vec<usize>> {
2342        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2343
2344        let per_layer_elems = {
2345            let input_layernorm = cfg.hidden_size;
2346            let post_attention_layernorm = cfg.hidden_size;
2347
2348            let size_in = cfg.hidden_size;
2349            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2350            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2351            let q_proj =
2352                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2353            let k_proj =
2354                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2355            let v_proj =
2356                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2357            let o_proj =
2358                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2359
2360            let moe_block = {
2361                let gate = cfg.hidden_size * cfg.num_local_experts;
2362                // Assume the expert weights are quantized, so divide by the weight pack factor
2363                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2364                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2365                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2366                gate + cfg.num_local_experts * w1
2367                    + cfg.num_local_experts * w2
2368                    + cfg.num_local_experts * w3
2369            };
2370
2371            input_layernorm
2372                + post_attention_layernorm
2373                + q_proj
2374                + k_proj
2375                + v_proj
2376                + o_proj
2377                + moe_block
2378        };
2379        Ok(vec![
2380            per_layer_elems * dtype.size_in_bytes();
2381            cfg.num_hidden_layers
2382        ])
2383    }
2384
2385    fn num_layers(&self, config: &str) -> Result<usize> {
2386        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2387
2388        Ok(cfg.num_hidden_layers)
2389    }
2390
2391    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2392        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2393
2394        let cfg = ModelConfigMetadata {
2395            max_seq_len: cfg.max_position_embeddings,
2396            num_layers: cfg.num_hidden_layers,
2397            hidden_size: cfg.hidden_size,
2398            num_kv_heads: cfg.num_key_value_heads,
2399            num_attn_heads: cfg.num_attention_heads,
2400            sliding_window: cfg.sliding_window,
2401            k_head_dim: cfg.head_dim(),
2402            v_head_dim: cfg.head_dim(),
2403        };
2404
2405        Ok(Box::new(cfg))
2406    }
2407}
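
// Editorial sketch: a worked instance of the MoE block element count above, using
// illustrative values (16 experts, hidden 4096, intermediate 6400, no packing). These are
// not read from a real config; the point is that the expert weights dominate the per-layer
// footprint.
#[cfg(test)]
mod phi3_5_moe_block_size_sketch {
    #[test]
    fn moe_block_elems_worked_example() {
        let (hidden_size, intermediate_size, num_local_experts) = (4096usize, 6400usize, 16usize);
        let weight_pack_factor = 1usize;

        let gate = hidden_size * num_local_experts;
        let w1 = hidden_size * intermediate_size / weight_pack_factor;
        let w2 = hidden_size * intermediate_size / weight_pack_factor;
        let w3 = hidden_size * intermediate_size / weight_pack_factor;
        let moe_block = gate + num_local_experts * (w1 + w2 + w3);

        // ~1.26e9 elements per layer for the MoE block alone.
        assert_eq!(moe_block, 1_258_356_736);
    }
}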
2408
2409/// [`NormalLoader`] for a DeepSeekV2 model.
2410///
2411/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2412pub struct DeepSeekV2Loader;
2413
2414impl NormalModelLoader for DeepSeekV2Loader {
2415    fn load(
2416        &self,
2417        config: &str,
2418        vb: ShardedVarBuilder,
2419        normal_loading_metadata: NormalLoadingMetadata,
2420        attention_mechanism: AttentionImplementation,
2421    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2422        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2423
2424        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2425            &cfg,
2426            vb,
2427            self.is_gptx(config)?,
2428            normal_loading_metadata,
2429            attention_mechanism,
2430        )?))
2431    }
2432    fn load_xlora(
2433        &self,
2434        _config: &str,
2435        _vb: ShardedVarBuilder,
2436        _lora_config: &[((String, String), LoraConfig)],
2437        _xlora_config: Option<XLoraConfig>,
2438        _xlora_ordering: Ordering,
2439        _normal_loading_metadata: NormalLoadingMetadata,
2440        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2441    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2442        todo!()
2443    }
2444    fn is_gptx(&self, _: &str) -> Result<bool> {
2445        Ok(true)
2446    }
2447    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2448        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2449        Ok(Box::new(cfg))
2450    }
2451}
2452
2453impl IsqModelLoader for DeepSeekV2Loader {
2454    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2455        let mut data = vec![
2456            Regex::new(r"lm_head\.(weight|bias)$")?,
2457            // Attention
2458            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2459            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2460            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2461        ];
2462        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2463        if cfg.q_lora_rank.is_some() {
2464            data.extend(vec![
2465                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2466                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2467            ]);
2468        } else {
2469            data.push(Regex::new(
2470                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2471            )?);
2472        }
2473        for layer_idx in 0..cfg.num_hidden_layers {
2474            if cfg.n_routed_experts.is_some()
2475                && layer_idx >= cfg.first_k_dense_replace
2476                && layer_idx % cfg.moe_layer_freq == 0
2477            {
2478                for i in 0..cfg.n_routed_experts.unwrap() {
2479                    data.extend(vec![
2480                        Regex::new(&format!(
2481                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2482                        ))?,
2483                        Regex::new(&format!(
2484                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2485                        ))?,
2486                        Regex::new(&format!(
2487                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2488                        ))?,
2489                    ]);
2490                }
2491                if cfg.n_shared_experts.is_some() {
2492                    data.extend(vec![
2493                        Regex::new(&format!(
2494                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2495                        ))?,
2496                        Regex::new(&format!(
2497                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2498                        ))?,
2499                        Regex::new(&format!(
2500                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2501                        ))?,
2502                    ]);
2503                }
2504            } else {
2505                data.extend(vec![
2506                    Regex::new(&format!(
2507                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2508                    ))?,
2509                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2510                    Regex::new(&format!(
2511                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2512                    ))?,
2513                ]);
2514            };
2515        }
2516        Ok(data)
2517    }
2518    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2519        self.isq_layer_regexes(config)
2520    }
2521
2522    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2523        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2524        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2525        for layer_idx in 0..cfg.num_hidden_layers {
2526            if cfg.n_routed_experts.is_some()
2527                && layer_idx >= cfg.first_k_dense_replace
2528                && layer_idx % cfg.moe_layer_freq == 0
2529            {
2530                for i in 0..cfg.n_routed_experts.unwrap() {
2531                    data.extend(vec![
2532                        Regex::new(&format!(
2533                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2534                        ))?,
2535                        Regex::new(&format!(
2536                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2537                        ))?,
2538                        Regex::new(&format!(
2539                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2540                        ))?,
2541                    ]);
2542                }
2543                if cfg.n_shared_experts.is_some() {
2544                    data.extend(vec![
2545                        Regex::new(&format!(
2546                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2547                        ))?,
2548                        Regex::new(&format!(
2549                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2550                        ))?,
2551                        Regex::new(&format!(
2552                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2553                        ))?,
2554                    ]);
2555                }
2556            } else {
2557                data.extend(vec![
2558                    Regex::new(&format!(
2559                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2560                    ))?,
2561                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2562                    Regex::new(&format!(
2563                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2564                    ))?,
2565                ]);
2566            };
2567        }
2568        Ok(data)
2569    }
2570    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2571        self.isq_layer_regexes_moqe(config)
2572    }
2573}
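
// Editorial sketch: the predicate used above (and again in `layer_sizes_in_bytes`) to decide
// whether a DeepSeek layer is a routed-MoE layer or a dense MLP layer. The test values are
// hypothetical (dense first layer, MoE everywhere after).
#[cfg(test)]
mod deepseek_moe_layer_predicate_sketch {
    fn is_moe_layer(
        layer_idx: usize,
        n_routed_experts: Option<usize>,
        first_k_dense_replace: usize,
        moe_layer_freq: usize,
    ) -> bool {
        n_routed_experts.is_some()
            && layer_idx >= first_k_dense_replace
            && layer_idx % moe_layer_freq == 0
    }

    #[test]
    fn first_layer_is_dense_then_moe() {
        let (experts, first_k, freq) = (Some(64), 1, 1);
        assert!(!is_moe_layer(0, experts, first_k, freq));
        assert!(is_moe_layer(1, experts, first_k, freq));
        assert!(is_moe_layer(27, experts, first_k, freq));
        // Without routed experts the model is dense everywhere.
        assert!(!is_moe_layer(27, None, first_k, freq));
    }
}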
2574
2575impl DeviceMappedModelLoader for DeepSeekV2Loader {
2576    fn mapped_max_act_size_elems(
2577        &self,
2578        config: &str,
2579        params: &AutoDeviceMapParams,
2580    ) -> Result<usize> {
2581        let AutoDeviceMapParams::Text {
2582            max_seq_len,
2583            max_batch_size,
2584        } = params
2585        else {
2586            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2587        };
2588
2589        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2590
2591        Ok(
2592            max_batch_size
2593                * cfg.num_attention_heads
2594                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2595        )
2596    }
2597    fn non_mapped_max_act_size_elems(
2598        &self,
2599        _config: &str,
2600        _params: &AutoDeviceMapParams,
2601    ) -> Result<usize> {
2602        Ok(0)
2603    }
2604
2605    fn non_mapped_size_in_bytes(
2606        &self,
2607        config: &str,
2608        dtype: DType,
2609        weight_pack_factor: usize,
2610        _matformer_config: Option<&MatformerSliceConfig>,
2611    ) -> Result<usize> {
2612        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2613        let elems = {
2614            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2615            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2616            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2617                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2618            } else {
2619                0
2620            };
2621            let norm = cfg.hidden_size;
2622            embed_tokens + lm_head + norm
2623        };
2624        Ok(elems * dtype.size_in_bytes())
2625    }
2626
2627    fn layer_sizes_in_bytes(
2628        &self,
2629        config: &str,
2630        dtype: DType,
2631        weight_pack_factor: usize,
2632        _matformer_config: Option<&MatformerSliceConfig>,
2633    ) -> Result<Vec<usize>> {
2634        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2635        let mut per_layer_elems = Vec::new();
2636
2637        for layer_idx in 0..cfg.num_hidden_layers {
2638            let input_layernorm = cfg.hidden_size;
2639            let post_attention_layernorm = cfg.hidden_size;
2640
2641            let q_proj = match cfg.q_lora_rank {
2642                Some(lora_rank) => {
2643                    let a = cfg.hidden_size * lora_rank;
2644                    let norm = lora_rank;
2645                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2646                    a + norm + b
2647                }
2648                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2649            };
2650            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2651                / weight_pack_factor
2652                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2653            let kv_a_layernorm = cfg.kv_lora_rank;
2654            let kv_b_proj = cfg.kv_lora_rank
2655                * cfg.num_attention_heads
2656                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2657                / weight_pack_factor;
2658            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2659                / weight_pack_factor
2660                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2661
2662            let moe_block = {
2663                let mut sum = 0;
2664                if cfg.n_routed_experts.is_some()
2665                    && layer_idx >= cfg.first_k_dense_replace
2666                    && layer_idx % cfg.moe_layer_freq == 0
2667                {
2668                    let h_size = cfg.hidden_size;
2669                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2670                        * cfg.n_routed_experts.unwrap();
2671                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2672                        * cfg.n_routed_experts.unwrap();
2673                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2674                        * cfg.n_routed_experts.unwrap();
2675                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2676                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2677                            / weight_pack_factor;
2678                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2679                            / weight_pack_factor;
2680                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2681                            / weight_pack_factor;
2682                        gate_proj + up_proj + down_proj
2683                    } else {
2684                        0
2685                    };
2686                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2687                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2688                } else {
2689                    let h_size = cfg.hidden_size;
2690                    let i_size = cfg.intermediate_size;
2691                    let gate_proj = h_size * i_size / weight_pack_factor;
2692                    let up_proj = h_size * i_size / weight_pack_factor;
2693                    let down_proj = i_size * h_size / weight_pack_factor;
2694                    sum += gate_proj + up_proj + down_proj;
2695                }
2696                sum
2697            };
2698
2699            per_layer_elems.push(
2700                input_layernorm
2701                    + post_attention_layernorm
2702                    + q_proj
2703                    + kv_a_layernorm
2704                    + kv_a_proj_with_mqa
2705                    + kv_b_proj
2706                    + o_proj
2707                    + moe_block,
2708            );
2709        }
2710
2711        Ok(per_layer_elems
2712            .into_iter()
2713            .map(|x| x * dtype.size_in_bytes())
2714            .collect())
2715    }
2716
2717    fn num_layers(&self, config: &str) -> Result<usize> {
2718        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2719        Ok(cfg.num_hidden_layers)
2720    }
2721
2722    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2723        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2724
2725        let cfg = ModelConfigMetadata {
2726            max_seq_len: cfg.max_position_embeddings,
2727            num_layers: cfg.num_hidden_layers,
2728            hidden_size: cfg.hidden_size,
2729            num_kv_heads: cfg.num_attention_heads,
2730            num_attn_heads: cfg.num_attention_heads,
2731            sliding_window: None,
2732            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2733            v_head_dim: cfg.v_head_dim,
2734        };
2735
2736        Ok(Box::new(cfg))
2737    }
2738}
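
// Editorial sketch: with DeepSeek-style latent attention the K and V head dims differ, so
// `model_config` above reports `k_head_dim = qk_rope_head_dim + qk_nope_head_dim` separately
// from `v_head_dim`. The values below are illustrative, not read from a real DeepSeek config.
#[cfg(test)]
mod deepseek_head_dim_sketch {
    #[test]
    fn k_head_dim_combines_rope_and_nope_parts() {
        let (qk_rope_head_dim, qk_nope_head_dim, v_head_dim) = (64usize, 128usize, 128usize);
        let k_head_dim = qk_rope_head_dim + qk_nope_head_dim;
        assert_eq!(k_head_dim, 192);
        assert_ne!(k_head_dim, v_head_dim);
    }
}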
2739
2740/// [`NormalLoader`] for a DeepSeekV3 model.
2741///
2742/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2743pub struct DeepSeekV3Loader;
2744
2745impl NormalModelLoader for DeepSeekV3Loader {
2746    fn load(
2747        &self,
2748        config: &str,
2749        vb: ShardedVarBuilder,
2750        normal_loading_metadata: NormalLoadingMetadata,
2751        attention_mechanism: AttentionImplementation,
2752    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2753        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2754        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2755            &cfg,
2756            vb,
2757            self.is_gptx(config)?,
2758            normal_loading_metadata,
2759            attention_mechanism,
2760        )?))
2761    }
2762    fn load_xlora(
2763        &self,
2764        _config: &str,
2765        _vb: ShardedVarBuilder,
2766        _lora_config: &[((String, String), LoraConfig)],
2767        _xlora_config: Option<XLoraConfig>,
2768        _xlora_ordering: Ordering,
2769        _normal_loading_metadata: NormalLoadingMetadata,
2770        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2771    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2772        todo!()
2773    }
2774    fn is_gptx(&self, _: &str) -> Result<bool> {
2775        Ok(true)
2776    }
2777    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2778        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2779        Ok(Box::new(cfg))
2780    }
2781}
2782
2783impl IsqModelLoader for DeepSeekV3Loader {
2784    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2785        let mut data = vec![
2786            Regex::new(r"lm_head\.(weight|bias)$")?,
2787            // Attention
2788            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2789            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2790            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2791        ];
2792        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2793        if cfg.q_lora_rank.is_some() {
2794            data.extend(vec![
2795                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2796                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2797            ]);
2798        } else {
2799            data.push(Regex::new(
2800                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2801            )?);
2802        }
2803        for layer_idx in 0..cfg.num_hidden_layers {
2804            if cfg.n_routed_experts.is_some()
2805                && layer_idx >= cfg.first_k_dense_replace
2806                && layer_idx % cfg.moe_layer_freq == 0
2807            {
2808                for i in 0..cfg.n_routed_experts.unwrap() {
2809                    data.extend(vec![
2810                        Regex::new(&format!(
2811                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2812                        ))?,
2813                        Regex::new(&format!(
2814                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2815                        ))?,
2816                        Regex::new(&format!(
2817                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2818                        ))?,
2819                    ]);
2820                }
2821                if cfg.n_shared_experts.is_some() {
2822                    data.extend(vec![
2823                        Regex::new(&format!(
2824                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2825                        ))?,
2826                        Regex::new(&format!(
2827                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2828                        ))?,
2829                        Regex::new(&format!(
2830                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2831                        ))?,
2832                    ]);
2833                }
2834            } else {
2835                data.extend(vec![
2836                    Regex::new(&format!(
2837                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2838                    ))?,
2839                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2840                    Regex::new(&format!(
2841                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2842                    ))?,
2843                ]);
2844            };
2845        }
2846        Ok(data)
2847    }
2848    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2849        self.isq_layer_regexes(config)
2850    }
2851
2852    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2853        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2854        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2855        for layer_idx in 0..cfg.num_hidden_layers {
2856            if cfg.n_routed_experts.is_some()
2857                && layer_idx >= cfg.first_k_dense_replace
2858                && layer_idx % cfg.moe_layer_freq == 0
2859            {
2860                for i in 0..cfg.n_routed_experts.unwrap() {
2861                    data.extend(vec![
2862                        Regex::new(&format!(
2863                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2864                        ))?,
2865                        Regex::new(&format!(
2866                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2867                        ))?,
2868                        Regex::new(&format!(
2869                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2870                        ))?,
2871                    ]);
2872                }
2873                if cfg.n_shared_experts.is_some() {
2874                    data.extend(vec![
2875                        Regex::new(&format!(
2876                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2877                        ))?,
2878                        Regex::new(&format!(
2879                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2880                        ))?,
2881                        Regex::new(&format!(
2882                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2883                        ))?,
2884                    ]);
2885                }
2886            } else {
2887                data.extend(vec![
2888                    Regex::new(&format!(
2889                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2890                    ))?,
2891                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2892                    Regex::new(&format!(
2893                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2894                    ))?,
2895                ]);
2896            };
2897        }
2898        Ok(data)
2899    }
2900    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2901        self.isq_layer_regexes_moqe(config)
2902    }
2903}
2904
2905impl DeviceMappedModelLoader for DeepSeekV3Loader {
2906    fn mapped_max_act_size_elems(
2907        &self,
2908        config: &str,
2909        params: &AutoDeviceMapParams,
2910    ) -> Result<usize> {
2911        let AutoDeviceMapParams::Text {
2912            max_seq_len,
2913            max_batch_size,
2914        } = params
2915        else {
2916            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2917        };
2918
2919        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2920
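        // Peak mapped activation: one seq x seq attention-score matrix per head and batch
        // element, with the sequence length capped at ATTENTION_CHUNK_SIZE.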
2921        Ok(
2922            max_batch_size
2923                * cfg.num_attention_heads
2924                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2925        )
2926    }
2927    fn non_mapped_max_act_size_elems(
2928        &self,
2929        _config: &str,
2930        _params: &AutoDeviceMapParams,
2931    ) -> Result<usize> {
2932        Ok(0)
2933    }
2934
2935    fn non_mapped_size_in_bytes(
2936        &self,
2937        config: &str,
2938        dtype: DType,
2939        weight_pack_factor: usize,
2940        _matformer_config: Option<&MatformerSliceConfig>,
2941    ) -> Result<usize> {
2942        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2943        let elems = {
2944            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2945            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2946            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2947                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2948            } else {
2949                0
2950            };
2951            let norm = cfg.hidden_size;
2952            embed_tokens + lm_head + norm
2953        };
2954        Ok(elems * dtype.size_in_bytes())
2955    }
2956
2957    fn layer_sizes_in_bytes(
2958        &self,
2959        config: &str,
2960        dtype: DType,
2961        weight_pack_factor: usize,
2962        _matformer_config: Option<&MatformerSliceConfig>,
2963    ) -> Result<Vec<usize>> {
2964        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2965        let mut per_layer_elems = Vec::new();
2966
2967        for layer_idx in 0..cfg.num_hidden_layers {
2968            let input_layernorm = cfg.hidden_size;
2969            let post_attention_layernorm = cfg.hidden_size;
2970
2971            let q_proj = match cfg.q_lora_rank {
2972                Some(lora_rank) => {
2973                    let a = cfg.hidden_size * lora_rank;
2974                    let norm = lora_rank;
2975                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2976                    a + norm + b
2977                }
2978                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2979            };
2980            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2981                / weight_pack_factor
2982                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2983            let kv_a_layernorm = cfg.kv_lora_rank;
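            // kv_b_proj expands the compressed latent to per-head K (the no-RoPE part,
            // q_head_dim - qk_rope_head_dim) plus per-head V.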
2984            let kv_b_proj = cfg.kv_lora_rank
2985                * cfg.num_attention_heads
2986                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2987                / weight_pack_factor;
2988            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2989                / weight_pack_factor
2990                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2991
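            // Element count for the MLP block: MoE layers sum every routed expert's
            // gate/up/down projections, the optional shared experts, and the router gate;
            // dense layers count a single gate/up/down MLP.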
2992            let moe_block = {
2993                let mut sum = 0;
2994                if cfg.n_routed_experts.is_some()
2995                    && layer_idx >= cfg.first_k_dense_replace
2996                    && layer_idx % cfg.moe_layer_freq == 0
2997                {
2998                    let h_size = cfg.hidden_size;
2999                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3000                        * cfg.n_routed_experts.unwrap();
3001                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3002                        * cfg.n_routed_experts.unwrap();
3003                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3004                        * cfg.n_routed_experts.unwrap();
3005                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3006                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3007                            / weight_pack_factor;
3008                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3009                            / weight_pack_factor;
3010                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3011                            / weight_pack_factor;
3012                        gate_proj + up_proj + down_proj
3013                    } else {
3014                        0
3015                    };
3016                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
3017                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3018                } else {
3019                    let h_size = cfg.hidden_size;
3020                    let i_size = cfg.intermediate_size;
3021                    let gate_proj = h_size * i_size / weight_pack_factor;
3022                    let up_proj = h_size * i_size / weight_pack_factor;
3023                    let down_proj = i_size * h_size / weight_pack_factor;
3024                    sum += gate_proj + up_proj + down_proj;
3025                }
3026                sum
3027            };
3028
3029            per_layer_elems.push(
3030                input_layernorm
3031                    + post_attention_layernorm
3032                    + q_proj
3033                    + kv_a_layernorm
3034                    + kv_a_proj_with_mqa
3035                    + kv_b_proj
3036                    + o_proj
3037                    + moe_block,
3038            );
3039        }
3040
3041        Ok(per_layer_elems
3042            .into_iter()
3043            .map(|x| x * dtype.size_in_bytes())
3044            .collect())
3045    }
3046
3047    fn num_layers(&self, config: &str) -> Result<usize> {
3048        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3049        Ok(cfg.num_hidden_layers)
3050    }
3051
3052    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3053        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3054
3055        let cfg = ModelConfigMetadata {
3056            max_seq_len: cfg.max_position_embeddings,
3057            num_layers: cfg.num_hidden_layers,
3058            hidden_size: cfg.hidden_size,
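            // MLA decompresses K/V per attention head, so the KV head count is reported as
            // the full attention head count here.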
3059            num_kv_heads: cfg.num_attention_heads,
3060            num_attn_heads: cfg.num_attention_heads,
3061            sliding_window: None,
3062            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3063            v_head_dim: cfg.v_head_dim,
3064        };
3065
3066        Ok(Box::new(cfg))
3067    }
3068}
3069
3070/// [`NormalLoader`] for a Qwen 3 model.
3071///
3072/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
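///
/// A minimal, illustrative sketch (not from the crate docs; `config_json` is assumed to hold
/// the model's HF `config.json` contents as a `&str`):
///
/// ```ignore
/// let loader = Qwen3Loader;
/// let layers = loader.num_layers(config_json)?;         // parsed from `num_hidden_layers`
/// let targets = loader.isq_layer_regexes(config_json)?; // linear layers eligible for ISQ
/// ```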
3073pub struct Qwen3Loader;
3074
3075impl NormalModelLoader for Qwen3Loader {
3076    fn load(
3077        &self,
3078        config: &str,
3079        vb: ShardedVarBuilder,
3080        normal_loading_metadata: NormalLoadingMetadata,
3081        attention_mechanism: AttentionImplementation,
3082    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3083        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3084
3085        Ok(Box::new(models::qwen3::Model::new(
3086            &cfg,
3087            vb,
3088            self.is_gptx(config)?,
3089            normal_loading_metadata,
3090            attention_mechanism,
3091        )?))
3092    }
3093    fn load_xlora(
3094        &self,
3095        _config: &str,
3096        _vb: ShardedVarBuilder,
3097        _lora_config: &[((String, String), LoraConfig)],
3098        _xlora_config: Option<XLoraConfig>,
3099        _xlora_ordering: Ordering,
3100        _normal_loading_metadata: NormalLoadingMetadata,
3101        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3102    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3103        todo!()
3104    }
3105    fn is_gptx(&self, _: &str) -> Result<bool> {
3106        Ok(true)
3107    }
3108    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3109        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3110
3111        Ok(Box::new(cfg))
3112    }
3113}
3114
3115impl IsqModelLoader for Qwen3Loader {
3116    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
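        // Quantizable linear layers: lm_head plus the per-layer attention and MLP projections
        // (e.g. these match a tensor named `model.layers.0.self_attn.q_proj.weight`).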
3117        Ok(vec![
3118            Regex::new(r"lm_head\.(weight|bias)$")?,
3119            // Attention
3120            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3121            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3122            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3123            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3124            // MLP
3125            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3126            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3127            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3128        ])
3129    }
3130    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3131        self.isq_layer_regexes(config)
3132    }
3133}
3134
3135impl DeviceMappedModelLoader for Qwen3Loader {
3136    fn mapped_max_act_size_elems(
3137        &self,
3138        config: &str,
3139        params: &AutoDeviceMapParams,
3140    ) -> Result<usize> {
3141        let AutoDeviceMapParams::Text {
3142            max_seq_len,
3143            max_batch_size,
3144        } = params
3145        else {
3146            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3147        };
3148
3149        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3150
3151        Ok(
3152            max_batch_size
3153                * cfg.num_attention_heads
3154                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3155        )
3156    }
3157    fn non_mapped_max_act_size_elems(
3158        &self,
3159        _config: &str,
3160        _params: &AutoDeviceMapParams,
3161    ) -> Result<usize> {
3162        Ok(0)
3163    }
3164
3165    fn non_mapped_size_in_bytes(
3166        &self,
3167        config: &str,
3168        dtype: DType,
3169        weight_pack_factor: usize,
3170        _matformer_config: Option<&MatformerSliceConfig>,
3171    ) -> Result<usize> {
3172        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3173        let elems = {
3174            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3175            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3176            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3177                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3178            } else {
3179                0
3180            };
3181            let norm = cfg.hidden_size;
3182            embed_tokens + lm_head + norm
3183        };
3184        Ok(elems * dtype.size_in_bytes())
3185    }
3186
3187    fn layer_sizes_in_bytes(
3188        &self,
3189        config: &str,
3190        dtype: DType,
3191        weight_pack_factor: usize,
3192        _matformer_config: Option<&MatformerSliceConfig>,
3193    ) -> Result<Vec<usize>> {
3194        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3195        let per_layer_elems = {
3196            let input_layernorm = cfg.hidden_size;
3197            let post_attention_layernorm = cfg.hidden_size;
3198
3199            let size_in = cfg.hidden_size;
3200            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3201            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3202            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3203            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3204            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3205            let o_proj = size_q * size_in / weight_pack_factor;
3206
3207            let h_size = cfg.hidden_size;
3208            let i_size = cfg.intermediate_size;
3209            let gate_proj = h_size * i_size / weight_pack_factor;
3210            let up_proj = h_size * i_size / weight_pack_factor;
3211            let down_proj = i_size * h_size / weight_pack_factor;
3212
3213            let q_norm = cfg.head_dim();
3214            let k_norm = cfg.head_dim();
3215
3216            input_layernorm
3217                + post_attention_layernorm
3218                + q_proj
3219                + k_proj
3220                + v_proj
3221                + o_proj
3222                + gate_proj
3223                + up_proj
3224                + down_proj
3225                + q_norm
3226                + k_norm
3227        };
3228        Ok(vec![
3229            per_layer_elems * dtype.size_in_bytes();
3230            cfg.num_hidden_layers
3231        ])
3232    }
3233
3234    fn num_layers(&self, config: &str) -> Result<usize> {
3235        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3236        Ok(cfg.num_hidden_layers)
3237    }
3238
3239    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3240        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3241
3242        let cfg = ModelConfigMetadata {
3243            max_seq_len: cfg.max_position_embeddings,
3244            num_layers: cfg.num_hidden_layers,
3245            hidden_size: cfg.hidden_size,
3246            num_kv_heads: cfg.num_key_value_heads,
3247            num_attn_heads: cfg.num_attention_heads,
3248            sliding_window: cfg.sliding_window,
3249            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3250            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3251        };
3252
3253        Ok(Box::new(cfg))
3254    }
3255}
3256
3257/// [`NormalLoader`] for a GLM 4 model.
3258///
3259/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3260pub struct GLM4Loader;
3261
3262impl NormalModelLoader for GLM4Loader {
3263    fn load(
3264        &self,
3265        config: &str,
3266        vb: ShardedVarBuilder,
3267        normal_loading_metadata: NormalLoadingMetadata,
3268        attention_mechanism: AttentionImplementation,
3269    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3270        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3271
3272        Ok(Box::new(models::glm4::Model::new(
3273            &cfg,
3274            vb,
3275            self.is_gptx(config)?,
3276            normal_loading_metadata,
3277            attention_mechanism,
3278        )?))
3279    }
3280    fn load_xlora(
3281        &self,
3282        _config: &str,
3283        _vb: ShardedVarBuilder,
3284        _lora_config: &[((String, String), LoraConfig)],
3285        _xlora_config: Option<XLoraConfig>,
3286        _xlora_ordering: Ordering,
3287        _normal_loading_metadata: NormalLoadingMetadata,
3288        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3289    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3290        todo!()
3291    }
3292    fn is_gptx(&self, _: &str) -> Result<bool> {
3293        Ok(true)
3294    }
3295    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3296        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3297
3298        Ok(Box::new(cfg))
3299    }
3300}
3301
3302impl IsqModelLoader for GLM4Loader {
3303    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3304        Ok(vec![
3305            Regex::new(r"lm_head\.(weight|bias)$")?,
3306            // Attention
3307            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3308            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3309            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3310            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3311            // MLP
3312            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3313            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3314            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3315        ])
3316    }
3317    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3318        self.isq_layer_regexes(config)
3319    }
3320}
3321
3322impl DeviceMappedModelLoader for GLM4Loader {
3323    fn mapped_max_act_size_elems(
3324        &self,
3325        config: &str,
3326        params: &AutoDeviceMapParams,
3327    ) -> Result<usize> {
3328        let AutoDeviceMapParams::Text {
3329            max_seq_len,
3330            max_batch_size,
3331        } = params
3332        else {
3333            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3334        };
3335
3336        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3337
3338        Ok(
3339            max_batch_size
3340                * cfg.num_attention_heads
3341                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3342        )
3343    }
3344    fn non_mapped_max_act_size_elems(
3345        &self,
3346        _config: &str,
3347        _params: &AutoDeviceMapParams,
3348    ) -> Result<usize> {
3349        Ok(0)
3350    }
3351
3352    fn non_mapped_size_in_bytes(
3353        &self,
3354        config: &str,
3355        dtype: DType,
3356        weight_pack_factor: usize,
3357        _matformer_config: Option<&MatformerSliceConfig>,
3358    ) -> Result<usize> {
3359        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3360        let elems = {
3361            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3362            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3363            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3364                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3365            } else {
3366                0
3367            };
3368            let norm = cfg.hidden_size;
3369            embed_tokens + lm_head + norm
3370        };
3371        Ok(elems * dtype.size_in_bytes())
3372    }
3373
3374    fn layer_sizes_in_bytes(
3375        &self,
3376        config: &str,
3377        dtype: DType,
3378        weight_pack_factor: usize,
3379        _matformer_config: Option<&MatformerSliceConfig>,
3380    ) -> Result<Vec<usize>> {
3381        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3382        let per_layer_elems = {
3383            let input_layernorm = cfg.hidden_size;
3384            let post_attention_layernorm = cfg.hidden_size * 3; // plus post_self_attn_layernorm and post_mlp_layernorm
3385
3386            let size_in = cfg.hidden_size;
3387            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3388            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3389            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3390            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3391            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3392            let o_proj = size_q * size_in / weight_pack_factor;
3393
3394            let h_size = cfg.hidden_size;
3395            let i_size = cfg.intermediate_size;
3396            let gate_proj = h_size * i_size / weight_pack_factor;
3397            let up_proj = h_size * i_size / weight_pack_factor;
3398            let down_proj = i_size * h_size / weight_pack_factor;
3399
3400            input_layernorm
3401                + post_attention_layernorm
3402                + q_proj
3403                + k_proj
3404                + v_proj
3405                + o_proj
3406                + gate_proj
3407                + up_proj
3408                + down_proj
3409        };
3410        Ok(vec![
3411            per_layer_elems * dtype.size_in_bytes();
3412            cfg.num_hidden_layers
3413        ])
3414    }
3415
3416    fn num_layers(&self, config: &str) -> Result<usize> {
3417        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3418        Ok(cfg.num_hidden_layers)
3419    }
3420
3421    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3422        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3423
3424        let cfg = ModelConfigMetadata {
3425            max_seq_len: cfg.max_position_embeddings,
3426            num_layers: cfg.num_hidden_layers,
3427            hidden_size: cfg.hidden_size,
3428            num_kv_heads: cfg.num_key_value_heads,
3429            num_attn_heads: cfg.num_attention_heads,
3430            sliding_window: cfg.sliding_window,
3431            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3432            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3433        };
3434
3435        Ok(Box::new(cfg))
3436    }
3437}
3438
3439/// [`NormalLoader`] for a GLM 4 MoE Lite model (GLM-4.7-Flash).
3440///
3441/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
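///
/// Like the DeepSeek V3 family, this architecture uses multi-head latent attention
/// (low-rank q and kv projections) together with a mix of dense and MoE MLP layers.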
3442pub struct GLM4MoeLiteLoader;
3443
3444impl NormalModelLoader for GLM4MoeLiteLoader {
3445    fn load(
3446        &self,
3447        config: &str,
3448        vb: ShardedVarBuilder,
3449        normal_loading_metadata: NormalLoadingMetadata,
3450        attention_mechanism: AttentionImplementation,
3451    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3452        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3453        Ok(Box::new(models::glm4_moe_lite::Glm4MoeLite::new(
3454            &cfg,
3455            vb,
3456            self.is_gptx(config)?,
3457            normal_loading_metadata,
3458            attention_mechanism,
3459        )?))
3460    }
3461    fn load_xlora(
3462        &self,
3463        _config: &str,
3464        _vb: ShardedVarBuilder,
3465        _lora_config: &[((String, String), LoraConfig)],
3466        _xlora_config: Option<XLoraConfig>,
3467        _xlora_ordering: Ordering,
3468        _normal_loading_metadata: NormalLoadingMetadata,
3469        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3470    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3471        todo!()
3472    }
3473    fn is_gptx(&self, _: &str) -> Result<bool> {
3474        Ok(true)
3475    }
3476    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3477        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3478        Ok(Box::new(cfg))
3479    }
3480}
3481
3482impl IsqModelLoader for GLM4MoeLiteLoader {
3483    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3484        let mut data = vec![
3485            Regex::new(r"lm_head\.(weight|bias)$")?,
3486            // Attention (MLA)
3487            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3488            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3489            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3490            // Q LoRA projections
3491            Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3492            Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3493        ];
3494        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3495        for layer_idx in 0..cfg.num_hidden_layers {
3496            if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3497                // MoE layer
3498                for i in 0..cfg.n_routed_experts {
3499                    data.extend(vec![
3500                        Regex::new(&format!(
3501                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3502                        ))?,
3503                        Regex::new(&format!(
3504                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3505                        ))?,
3506                        Regex::new(&format!(
3507                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3508                        ))?,
3509                    ]);
3510                }
3511                if cfg.n_shared_experts > 0 {
3512                    data.extend(vec![
3513                        Regex::new(&format!(
3514                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3515                        ))?,
3516                        Regex::new(&format!(
3517                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3518                        ))?,
3519                        Regex::new(&format!(
3520                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3521                        ))?,
3522                    ]);
3523                }
3524            } else {
3525                // Dense MLP layer
3526                data.extend(vec![
3527                    Regex::new(&format!(
3528                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3529                    ))?,
3530                    Regex::new(&format!(
3531                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3532                    ))?,
3533                    Regex::new(&format!(
3534                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3535                    ))?,
3536                ]);
3537            };
3538        }
3539        Ok(data)
3540    }
3541    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3542        self.isq_layer_regexes(config)
3543    }
3544
3545    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3546        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3547        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3548        for layer_idx in 0..cfg.num_hidden_layers {
3549            if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3550                // MoE layer
3551                for i in 0..cfg.n_routed_experts {
3552                    data.extend(vec![
3553                        Regex::new(&format!(
3554                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3555                        ))?,
3556                        Regex::new(&format!(
3557                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3558                        ))?,
3559                        Regex::new(&format!(
3560                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3561                        ))?,
3562                    ]);
3563                }
3564                if cfg.n_shared_experts > 0 {
3565                    data.extend(vec![
3566                        Regex::new(&format!(
3567                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3568                        ))?,
3569                        Regex::new(&format!(
3570                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3571                        ))?,
3572                        Regex::new(&format!(
3573                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3574                        ))?,
3575                    ]);
3576                }
3577            } else {
3578                // Dense MLP layer
3579                data.extend(vec![
3580                    Regex::new(&format!(
3581                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3582                    ))?,
3583                    Regex::new(&format!(
3584                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3585                    ))?,
3586                    Regex::new(&format!(
3587                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3588                    ))?,
3589                ]);
3590            };
3591        }
3592        Ok(data)
3593    }
3594    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3595        self.isq_layer_regexes_moqe(config)
3596    }
3597}
3598
3599impl DeviceMappedModelLoader for GLM4MoeLiteLoader {
3600    fn mapped_max_act_size_elems(
3601        &self,
3602        config: &str,
3603        params: &AutoDeviceMapParams,
3604    ) -> Result<usize> {
3605        let AutoDeviceMapParams::Text {
3606            max_seq_len,
3607            max_batch_size,
3608        } = params
3609        else {
3610            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3611        };
3612
3613        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3614
3615        Ok(
3616            max_batch_size
3617                * cfg.num_attention_heads
3618                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3619        )
3620    }
3621    fn non_mapped_max_act_size_elems(
3622        &self,
3623        _config: &str,
3624        _params: &AutoDeviceMapParams,
3625    ) -> Result<usize> {
3626        Ok(0)
3627    }
3628
3629    fn non_mapped_size_in_bytes(
3630        &self,
3631        config: &str,
3632        dtype: DType,
3633        weight_pack_factor: usize,
3634        _matformer_config: Option<&MatformerSliceConfig>,
3635    ) -> Result<usize> {
3636        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3637        let elems = {
3638            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3639            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3640            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3641                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3642            } else {
3643                0
3644            };
3645            let norm = cfg.hidden_size;
3646            embed_tokens + lm_head + norm
3647        };
3648        Ok(elems * dtype.size_in_bytes())
3649    }
3650
3651    fn layer_sizes_in_bytes(
3652        &self,
3653        config: &str,
3654        dtype: DType,
3655        weight_pack_factor: usize,
3656        _matformer_config: Option<&MatformerSliceConfig>,
3657    ) -> Result<Vec<usize>> {
3658        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3659        let mut per_layer_elems = Vec::new();
3660
3661        for layer_idx in 0..cfg.num_hidden_layers {
3662            let input_layernorm = cfg.hidden_size;
3663            let post_attention_layernorm = cfg.hidden_size;
3664
3665            // Q LoRA projection
3666            let q_proj = {
3667                let a = cfg.hidden_size * cfg.q_lora_rank / weight_pack_factor;
3668                let norm = cfg.q_lora_rank;
3669                let b = (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.q_lora_rank
3670                    / weight_pack_factor;
3671                a + norm + b
3672            };
3673            let kv_a_proj_with_mqa =
3674                cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim) / weight_pack_factor;
3675            let kv_a_layernorm = cfg.kv_lora_rank;
3676            let kv_b_proj = cfg.kv_lora_rank
3677                * cfg.num_attention_heads
3678                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3679                / weight_pack_factor;
3680            let o_proj =
3681                cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size / weight_pack_factor;
3682
3683            let moe_block = {
3684                let mut sum = 0;
3685                if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3686                    // MoE layer
3687                    let h_size = cfg.hidden_size;
3688                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3689                        * cfg.n_routed_experts;
3690                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3691                        * cfg.n_routed_experts;
3692                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3693                        * cfg.n_routed_experts;
3694                    let shared_experts = if cfg.n_shared_experts > 0 {
3695                        let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3696                        let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3697                        let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
3698                        gate_proj + up_proj + down_proj
3699                    } else {
3700                        0
3701                    };
3702                    let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
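                    // One routing-correction bias element per routed expert.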
3703                    let e_score_correction_bias = cfg.n_routed_experts;
3704                    sum += gate_proj
3705                        + up_proj
3706                        + down_proj
3707                        + shared_experts
3708                        + gate_weight
3709                        + e_score_correction_bias;
3710                } else {
3711                    // Dense MLP layer
3712                    let h_size = cfg.hidden_size;
3713                    let i_size = cfg.intermediate_size;
3714                    let gate_proj = h_size * i_size / weight_pack_factor;
3715                    let up_proj = h_size * i_size / weight_pack_factor;
3716                    let down_proj = i_size * h_size / weight_pack_factor;
3717                    sum += gate_proj + up_proj + down_proj;
3718                }
3719                sum
3720            };
3721
3722            per_layer_elems.push(
3723                input_layernorm
3724                    + post_attention_layernorm
3725                    + q_proj
3726                    + kv_a_layernorm
3727                    + kv_a_proj_with_mqa
3728                    + kv_b_proj
3729                    + o_proj
3730                    + moe_block,
3731            );
3732        }
3733
3734        Ok(per_layer_elems
3735            .into_iter()
3736            .map(|x| x * dtype.size_in_bytes())
3737            .collect())
3738    }
3739
3740    fn num_layers(&self, config: &str) -> Result<usize> {
3741        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3742        Ok(cfg.num_hidden_layers)
3743    }
3744
3745    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3746        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3747
3748        let cfg = ModelConfigMetadata {
3749            max_seq_len: cfg.max_position_embeddings,
3750            num_layers: cfg.num_hidden_layers,
3751            hidden_size: cfg.hidden_size,
3752            num_kv_heads: cfg.num_attention_heads,
3753            num_attn_heads: cfg.num_attention_heads,
3754            sliding_window: None,
3755            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3756            v_head_dim: cfg.v_head_dim,
3757        };
3758
3759        Ok(Box::new(cfg))
3760    }
3761}
3762
3763/// [`NormalLoader`] for a GLM 4 MoE model (GLM-4.5).
3764///
3765/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3766pub struct GLM4MoeLoader;
3767
3768impl NormalModelLoader for GLM4MoeLoader {
3769    fn load(
3770        &self,
3771        config: &str,
3772        vb: ShardedVarBuilder,
3773        normal_loading_metadata: NormalLoadingMetadata,
3774        attention_mechanism: AttentionImplementation,
3775    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3776        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3777        Ok(Box::new(models::glm4_moe::Glm4Moe::new(
3778            &cfg,
3779            vb,
3780            self.is_gptx(config)?,
3781            normal_loading_metadata,
3782            attention_mechanism,
3783        )?))
3784    }
3785    fn load_xlora(
3786        &self,
3787        _config: &str,
3788        _vb: ShardedVarBuilder,
3789        _lora_config: &[((String, String), LoraConfig)],
3790        _xlora_config: Option<XLoraConfig>,
3791        _xlora_ordering: Ordering,
3792        _normal_loading_metadata: NormalLoadingMetadata,
3793        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3794    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3795        todo!()
3796    }
3797    fn is_gptx(&self, _: &str) -> Result<bool> {
3798        Ok(true)
3799    }
3800    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3801        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3802        Ok(Box::new(cfg))
3803    }
3804}
3805
3806impl IsqModelLoader for GLM4MoeLoader {
3807    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3808        let mut data = vec![
3809            Regex::new(r"lm_head\.(weight|bias)$")?,
3810            // Attention (standard GQA)
3811            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3812            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3813            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3814            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3815        ];
3816        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3817        for layer_idx in 0..cfg.num_hidden_layers {
3818            if layer_idx >= cfg.first_k_dense_replace {
3819                // MoE layer
3820                for i in 0..cfg.n_routed_experts {
3821                    data.extend(vec![
3822                        Regex::new(&format!(
3823                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3824                        ))?,
3825                        Regex::new(&format!(
3826                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3827                        ))?,
3828                        Regex::new(&format!(
3829                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3830                        ))?,
3831                    ]);
3832                }
3833                if cfg.n_shared_experts > 0 {
3834                    data.extend(vec![
3835                        Regex::new(&format!(
3836                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3837                        ))?,
3838                        Regex::new(&format!(
3839                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3840                        ))?,
3841                        Regex::new(&format!(
3842                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3843                        ))?,
3844                    ]);
3845                }
3846            } else {
3847                // Dense MLP layer
3848                data.extend(vec![
3849                    Regex::new(&format!(
3850                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3851                    ))?,
3852                    Regex::new(&format!(
3853                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3854                    ))?,
3855                    Regex::new(&format!(
3856                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3857                    ))?,
3858                ]);
3859            };
3860        }
3861        Ok(data)
3862    }
3863    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3864        self.isq_layer_regexes(config)
3865    }
3866
3867    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3868        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3869        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3870        for layer_idx in 0..cfg.num_hidden_layers {
3871            if layer_idx >= cfg.first_k_dense_replace {
3872                // MoE layer
3873                for i in 0..cfg.n_routed_experts {
3874                    data.extend(vec![
3875                        Regex::new(&format!(
3876                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3877                        ))?,
3878                        Regex::new(&format!(
3879                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3880                        ))?,
3881                        Regex::new(&format!(
3882                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3883                        ))?,
3884                    ]);
3885                }
3886                if cfg.n_shared_experts > 0 {
3887                    data.extend(vec![
3888                        Regex::new(&format!(
3889                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3890                        ))?,
3891                        Regex::new(&format!(
3892                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3893                        ))?,
3894                        Regex::new(&format!(
3895                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3896                        ))?,
3897                    ]);
3898                }
3899            } else {
3900                // Dense MLP layer
3901                data.extend(vec![
3902                    Regex::new(&format!(
3903                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3904                    ))?,
3905                    Regex::new(&format!(
3906                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3907                    ))?,
3908                    Regex::new(&format!(
3909                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3910                    ))?,
3911                ]);
3912            };
3913        }
3914        Ok(data)
3915    }
3916    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3917        self.isq_layer_regexes_moqe(config)
3918    }
3919}
3920
3921impl DeviceMappedModelLoader for GLM4MoeLoader {
3922    fn mapped_max_act_size_elems(
3923        &self,
3924        config: &str,
3925        params: &AutoDeviceMapParams,
3926    ) -> Result<usize> {
3927        let AutoDeviceMapParams::Text {
3928            max_seq_len,
3929            max_batch_size,
3930        } = params
3931        else {
3932            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3933        };
3934
3935        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3936
3937        Ok(
3938            max_batch_size
3939                * cfg.num_attention_heads
3940                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3941        )
3942    }
3943    fn non_mapped_max_act_size_elems(
3944        &self,
3945        _config: &str,
3946        _params: &AutoDeviceMapParams,
3947    ) -> Result<usize> {
3948        Ok(0)
3949    }
3950
3951    fn non_mapped_size_in_bytes(
3952        &self,
3953        config: &str,
3954        dtype: DType,
3955        weight_pack_factor: usize,
3956        _matformer_config: Option<&MatformerSliceConfig>,
3957    ) -> Result<usize> {
3958        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3959        let elems = {
3960            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3961            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3962                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3963            } else {
3964                0
3965            };
3966            let norm = cfg.hidden_size;
3967            embed_tokens + lm_head + norm
3968        };
3969        Ok(elems * dtype.size_in_bytes())
3970    }
3971
3972    fn layer_sizes_in_bytes(
3973        &self,
3974        config: &str,
3975        dtype: DType,
3976        weight_pack_factor: usize,
3977        _matformer_config: Option<&MatformerSliceConfig>,
3978    ) -> Result<Vec<usize>> {
3979        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3980        let mut per_layer_elems = Vec::new();
3981
3982        let head_dim = cfg.head_dim();
3983        for layer_idx in 0..cfg.num_hidden_layers {
3984            let input_layernorm = cfg.hidden_size;
3985            let post_attention_layernorm = cfg.hidden_size;
3986
3987            // Standard GQA attention
3988            let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor
3989                + bias_if!(cfg.attention_bias, cfg.num_attention_heads * head_dim);
3990            let k_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
3991                + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
3992            let v_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
3993                + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
3994            let o_proj = cfg.num_attention_heads * head_dim * cfg.hidden_size / weight_pack_factor;
3995
3996            // QK norm if enabled
3997            let qk_norm = if cfg.use_qk_norm {
3998                head_dim * 2 // q_norm + k_norm
3999            } else {
4000                0
4001            };
4002
4003            let moe_block = {
4004                let mut sum = 0;
4005                if layer_idx >= cfg.first_k_dense_replace {
4006                    // MoE layer
4007                    let h_size = cfg.hidden_size;
4008                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4009                        * cfg.n_routed_experts;
4010                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4011                        * cfg.n_routed_experts;
4012                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
4013                        * cfg.n_routed_experts;
4014                    let shared_experts = if cfg.n_shared_experts > 0 {
4015                        let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4016                        let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4017                        let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
4018                        gate_proj + up_proj + down_proj
4019                    } else {
4020                        0
4021                    };
4022                    let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
4023                    let e_score_correction_bias = cfg.n_routed_experts;
4024                    sum += gate_proj
4025                        + up_proj
4026                        + down_proj
4027                        + shared_experts
4028                        + gate_weight
4029                        + e_score_correction_bias;
4030                } else {
4031                    // Dense MLP layer
4032                    let h_size = cfg.hidden_size;
4033                    let i_size = cfg.intermediate_size;
4034                    let gate_proj = h_size * i_size / weight_pack_factor;
4035                    let up_proj = h_size * i_size / weight_pack_factor;
4036                    let down_proj = i_size * h_size / weight_pack_factor;
4037                    sum += gate_proj + up_proj + down_proj;
4038                }
4039                sum
4040            };
4041
4042            per_layer_elems.push(
4043                input_layernorm
4044                    + post_attention_layernorm
4045                    + q_proj
4046                    + k_proj
4047                    + v_proj
4048                    + o_proj
4049                    + qk_norm
4050                    + moe_block,
4051            );
4052        }
4053
4054        Ok(per_layer_elems
4055            .into_iter()
4056            .map(|x| x * dtype.size_in_bytes())
4057            .collect())
4058    }
4059
4060    fn num_layers(&self, config: &str) -> Result<usize> {
4061        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4062        Ok(cfg.num_hidden_layers)
4063    }
4064
4065    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4066        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4067
4068        let head_dim = cfg.head_dim();
4069        let cfg = ModelConfigMetadata {
4070            max_seq_len: cfg.max_position_embeddings,
4071            num_layers: cfg.num_hidden_layers,
4072            hidden_size: cfg.hidden_size,
4073            num_kv_heads: cfg.num_key_value_heads,
4074            num_attn_heads: cfg.num_attention_heads,
4075            sliding_window: None,
4076            k_head_dim: head_dim,
4077            v_head_dim: head_dim,
4078        };
4079
4080        Ok(Box::new(cfg))
4081    }
4082}
4083
4084/// [`NormalLoader`] for a Qwen 3 MoE model.
4085///
4086/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
4087pub struct Qwen3MoELoader;
4088
4089impl NormalModelLoader for Qwen3MoELoader {
4090    fn load(
4091        &self,
4092        config: &str,
4093        vb: ShardedVarBuilder,
4094        normal_loading_metadata: NormalLoadingMetadata,
4095        attention_mechanism: AttentionImplementation,
4096    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4097        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4098
4099        Ok(Box::new(models::qwen3_moe::Model::new(
4100            &cfg,
4101            vb,
4102            self.is_gptx(config)?,
4103            normal_loading_metadata,
4104            attention_mechanism,
4105        )?))
4106    }
4107    fn load_xlora(
4108        &self,
4109        _config: &str,
4110        _vb: ShardedVarBuilder,
4111        _lora_config: &[((String, String), LoraConfig)],
4112        _xlora_config: Option<XLoraConfig>,
4113        _xlora_ordering: Ordering,
4114        _normal_loading_metadata: NormalLoadingMetadata,
4115        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4116    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4117        todo!()
4118    }
4119    fn is_gptx(&self, _: &str) -> Result<bool> {
4120        Ok(true)
4121    }
4122    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4123        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4124
4125        Ok(Box::new(cfg))
4126    }
4127}
4128
4129impl IsqModelLoader for Qwen3MoELoader {
4130    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4131        Ok(vec![
4132            Regex::new(r"lm_head\.(weight|bias)$")?,
4133            // Attention
4134            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4135            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4136            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4137            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4138            // MLP
4139            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4140            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4141            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4142            // MLP MoE
4143            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
4144            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
4145            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
4146        ])
4147    }
4148    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4149        self.isq_layer_regexes(config)
4150    }
4151    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
4152        self.isq_layer_regexes_moqe(config)
4153    }
4154}
4155
4156impl DeviceMappedModelLoader for Qwen3MoELoader {
4157    fn mapped_max_act_size_elems(
4158        &self,
4159        config: &str,
4160        params: &AutoDeviceMapParams,
4161    ) -> Result<usize> {
4162        let AutoDeviceMapParams::Text {
4163            max_seq_len,
4164            max_batch_size,
4165        } = params
4166        else {
4167            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4168        };
4169
4170        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4171
4172        Ok(
4173            max_batch_size
4174                * cfg.num_attention_heads
4175                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4176        )
4177    }
4178    fn non_mapped_max_act_size_elems(
4179        &self,
4180        _config: &str,
4181        _params: &AutoDeviceMapParams,
4182    ) -> Result<usize> {
4183        Ok(0)
4184    }
4185
4186    fn non_mapped_size_in_bytes(
4187        &self,
4188        config: &str,
4189        dtype: DType,
4190        weight_pack_factor: usize,
4191        _matformer_config: Option<&MatformerSliceConfig>,
4192    ) -> Result<usize> {
4193        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4194        let elems = {
4195            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4196            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4197            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4198                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4199            } else {
4200                0
4201            };
4202            let norm = cfg.hidden_size;
4203            embed_tokens + lm_head + norm
4204        };
4205        Ok(elems * dtype.size_in_bytes())
4206    }
4207
4208    fn layer_sizes_in_bytes(
4209        &self,
4210        config: &str,
4211        dtype: DType,
4212        weight_pack_factor: usize,
4213        _matformer_config: Option<&MatformerSliceConfig>,
4214    ) -> Result<Vec<usize>> {
4215        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4216
4217        let mut layer_sizes_in_bytes = Vec::new();
4218        for layer_idx in 0..cfg.num_hidden_layers {
4219            let input_layernorm = cfg.hidden_size;
4220            let post_attention_layernorm = cfg.hidden_size;
4221
4222            let size_in = cfg.hidden_size;
4223            let size_q = cfg.head_dim() * cfg.num_attention_heads;
4224            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
4225            let q_proj = size_in * size_q / weight_pack_factor;
4226            let k_proj = size_in * size_kv / weight_pack_factor;
4227            let v_proj = size_in * size_kv / weight_pack_factor;
4228            let o_proj = size_q * size_in / weight_pack_factor;
4229
4230            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
4231                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
4232            {
4233                let gate_size = cfg.hidden_size * cfg.num_experts;
4234                let expert_size = {
4235                    let h_size = cfg.hidden_size;
4236                    let i_size = cfg.moe_intermediate_size;
4237                    let gate_proj = h_size * i_size / weight_pack_factor;
4238                    let up_proj = h_size * i_size / weight_pack_factor;
4239                    let down_proj = i_size * h_size / weight_pack_factor;
4240                    gate_proj + up_proj + down_proj
4241                };
4242                expert_size * cfg.num_experts + gate_size
4243            } else {
4244                let h_size = cfg.hidden_size;
4245                let i_size = cfg.intermediate_size;
4246                let gate_proj = h_size * i_size / weight_pack_factor;
4247                let up_proj = h_size * i_size / weight_pack_factor;
4248                let down_proj = i_size * h_size / weight_pack_factor;
4249                gate_proj + up_proj + down_proj
4250            };
4251
4252            let q_norm = cfg.head_dim();
4253            let k_norm = cfg.head_dim();
4254
4255            let size_elems = input_layernorm
4256                + post_attention_layernorm
4257                + q_proj
4258                + k_proj
4259                + v_proj
4260                + o_proj
4261                + mlp_size
4262                + q_norm
4263                + k_norm;
4264
4265            let size_in_bytes = size_elems * dtype.size_in_bytes();
4266            layer_sizes_in_bytes.push(size_in_bytes);
4267        }
4268
4269        Ok(layer_sizes_in_bytes)
4270    }
4271
4272    fn num_layers(&self, config: &str) -> Result<usize> {
4273        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4274        Ok(cfg.num_hidden_layers)
4275    }
4276
4277    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4278        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4279
4280        let cfg = ModelConfigMetadata {
4281            max_seq_len: cfg.max_position_embeddings,
4282            num_layers: cfg.num_hidden_layers,
4283            hidden_size: cfg.hidden_size,
4284            num_kv_heads: cfg.num_key_value_heads,
4285            num_attn_heads: cfg.num_attention_heads,
4286            sliding_window: cfg.sliding_window,
4287            k_head_dim: cfg.head_dim(),
4288            v_head_dim: cfg.head_dim(),
4289        };
4290
4291        Ok(Box::new(cfg))
4292    }
4293}
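
// A back-of-envelope sketch, not part of the crate's tests, of how the branch
// in `layer_sizes_in_bytes` above picks sparse layers: a layer gets the MoE MLP
// when it is not listed in `mlp_only_layers`, there is at least one expert, and
// `(layer_idx + 1) % decoder_sparse_step == 0`. All values below are
// hypothetical, not taken from a real Qwen3-MoE config.
#[cfg(test)]
mod qwen3_moe_sparse_layer_sketch {
    #[test]
    fn sparse_step_selects_moe_layers() {
        // Hypothetical config values.
        let num_hidden_layers = 8usize;
        let decoder_sparse_step = 2usize;
        let num_experts = 4usize;
        let mlp_only_layers = [5usize];

        let is_moe = |layer_idx: usize| {
            !mlp_only_layers.contains(&layer_idx)
                && num_experts > 0
                && (layer_idx + 1) % decoder_sparse_step == 0
        };

        let moe_layers: Vec<usize> = (0..num_hidden_layers).filter(|&i| is_moe(i)).collect();
        // Every second layer is sparse, except layer 5, which is forced dense.
        assert_eq!(moe_layers, vec![1, 3, 7]);
    }
}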
4294
4295// ======================== SmolLm3 loader
4296
4297/// [`NormalLoader`] for a SmolLm3 model.
4298///
4299/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
4300pub struct SmolLm3Loader;
4301
4302impl NormalModelLoader for SmolLm3Loader {
4303    fn load(
4304        &self,
4305        config: &str,
4306        vb: ShardedVarBuilder,
4307        normal_loading_metadata: NormalLoadingMetadata,
4308        attention_mechanism: AttentionImplementation,
4309    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4310        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4311
4312        Ok(Box::new(models::smollm3::SmolLm3::new(
4313            &cfg,
4314            vb,
4315            self.is_gptx(config)?,
4316            normal_loading_metadata,
4317            attention_mechanism,
4318        )?))
4319    }
4320    fn load_xlora(
4321        &self,
4322        _config: &str,
4323        _vb: ShardedVarBuilder,
4324        _lora_config: &[((String, String), LoraConfig)],
4325        _xlora_config: Option<XLoraConfig>,
4326        _xlora_ordering: Ordering,
4327        _normal_loading_metadata: NormalLoadingMetadata,
4328        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4329    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4330        todo!()
4331    }
4332    fn is_gptx(&self, _: &str) -> Result<bool> {
4333        Ok(true)
4334    }
4335    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4336        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4337        Ok(Box::new(cfg))
4338    }
4339}
4340
4341impl IsqModelLoader for SmolLm3Loader {
4342    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4343        Ok(vec![
4344            Regex::new(r"lm_head\.(weight|bias)$")?,
4345            // Attention
4346            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4347            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4348            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4349            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4350            // MLP
4351            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4352            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4353            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4354        ])
4355    }
4356    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4357        self.isq_layer_regexes(config)
4358    }
4359}
4360
4361impl DeviceMappedModelLoader for SmolLm3Loader {
4362    fn mapped_max_act_size_elems(
4363        &self,
4364        config: &str,
4365        params: &AutoDeviceMapParams,
4366    ) -> Result<usize> {
4367        let AutoDeviceMapParams::Text {
4368            max_seq_len,
4369            max_batch_size,
4370        } = params
4371        else {
4372            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4373        };
4374
4375        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4376
4377        Ok(
4378            max_batch_size
4379                * cfg.num_attention_heads
4380                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4381        )
4382    }
4383    fn non_mapped_max_act_size_elems(
4384        &self,
4385        _config: &str,
4386        _params: &AutoDeviceMapParams,
4387    ) -> Result<usize> {
4388        Ok(0)
4389    }
4390
4391    fn non_mapped_size_in_bytes(
4392        &self,
4393        config: &str,
4394        dtype: DType,
4395        weight_pack_factor: usize,
4396        _matformer_config: Option<&MatformerSliceConfig>,
4397    ) -> Result<usize> {
4398        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4399
4400        let elems = {
4401            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4402            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4403            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4404                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4405            } else {
4406                0
4407            };
4408            let norm = cfg.hidden_size;
4409            embed_tokens + lm_head + norm
4410        };
4411        Ok(elems * dtype.size_in_bytes())
4412    }
4413
4414    fn layer_sizes_in_bytes(
4415        &self,
4416        config: &str,
4417        dtype: DType,
4418        weight_pack_factor: usize,
4419        _matformer_config: Option<&MatformerSliceConfig>,
4420    ) -> Result<Vec<usize>> {
4421        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4422
4423        let per_layer_elems = {
4424            let input_layernorm = cfg.hidden_size;
4425            let post_attention_layernorm = cfg.hidden_size;
4426
4427            let size_in = cfg.hidden_size;
4428            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4429            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4430            let q_proj = size_in * size_q / weight_pack_factor;
4431            let k_proj = size_in * size_kv / weight_pack_factor;
4432            let v_proj = size_in * size_kv / weight_pack_factor;
4433            let o_proj = size_q * size_in / weight_pack_factor;
4434
4435            let h_size = cfg.hidden_size;
4436            let i_size = cfg.intermediate_size;
4437            let gate_proj = h_size * i_size / weight_pack_factor;
4438            let up_proj = h_size * i_size / weight_pack_factor;
4439            let down_proj = i_size * h_size / weight_pack_factor;
4440
4441            input_layernorm
4442                + post_attention_layernorm
4443                + q_proj
4444                + k_proj
4445                + v_proj
4446                + o_proj
4447                + gate_proj
4448                + up_proj
4449                + down_proj
4450        };
4451        Ok(vec![
4452            per_layer_elems * dtype.size_in_bytes();
4453            cfg.num_hidden_layers
4454        ])
4455    }
4456
4457    fn num_layers(&self, config: &str) -> Result<usize> {
4458        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4459
4460        Ok(cfg.num_hidden_layers)
4461    }
4462    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4463        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4464
4465        let cfg = ModelConfigMetadata {
4466            max_seq_len: cfg.max_position_embeddings,
4467            num_layers: cfg.num_hidden_layers,
4468            hidden_size: cfg.hidden_size,
4469            num_kv_heads: cfg.num_key_value_heads,
4470            num_attn_heads: cfg.num_attention_heads,
4471            sliding_window: None,
4472            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4473            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4474        };
4475
4476        Ok(Box::new(cfg))
4477    }
4478}
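
// A minimal sketch, not part of the crate's tests, of what `weight_pack_factor`
// means in the per-layer accounting above: only the matmul weights are divided
// by it, while the layernorm parameters keep their full element count. The
// dimensions and the pack factor of 8 (e.g. eight 4-bit weights per 32-bit
// word) are hypothetical.
#[cfg(test)]
mod smollm3_pack_factor_sketch {
    #[test]
    fn pack_factor_shrinks_projections_but_not_norms() {
        // Hypothetical dimensions; only the layernorm and MLP terms are shown.
        let hidden_size = 2048usize;
        let intermediate_size = 8192usize;
        let weight_pack_factor = 8usize;

        let norms = 2 * hidden_size; // input + post-attention layernorm
        let mlp = 3 * hidden_size * intermediate_size; // gate + up + down projections

        let unpacked = norms + mlp;
        let packed = norms + mlp / weight_pack_factor;

        // The MLP dominates, so the per-layer estimate shrinks almost 8x.
        assert!(packed < unpacked / 7);
        assert_eq!(packed - norms, (unpacked - norms) / 8);
    }
}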
4479
4480// ======================== GraniteMoeHybrid loader
4481
4482/// [`NormalLoader`] for a GraniteMoeHybrid model (IBM Granite 4.0).
4483///
4484/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
4485pub struct GraniteMoeHybridLoader;
4486
4487impl NormalModelLoader for GraniteMoeHybridLoader {
4488    fn load(
4489        &self,
4490        config: &str,
4491        vb: ShardedVarBuilder,
4492        normal_loading_metadata: NormalLoadingMetadata,
4493        attention_mechanism: AttentionImplementation,
4494    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4495        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4496
4497        Ok(Box::new(models::granite::GraniteMoeHybrid::new(
4498            &cfg,
4499            vb,
4500            self.is_gptx(config)?,
4501            normal_loading_metadata,
4502            attention_mechanism,
4503        )?))
4504    }
4505    fn load_xlora(
4506        &self,
4507        _config: &str,
4508        _vb: ShardedVarBuilder,
4509        _lora_config: &[((String, String), LoraConfig)],
4510        _xlora_config: Option<XLoraConfig>,
4511        _xlora_ordering: Ordering,
4512        _normal_loading_metadata: NormalLoadingMetadata,
4513        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4514    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4515        todo!()
4516    }
4517    fn is_gptx(&self, _: &str) -> Result<bool> {
4518        Ok(true)
4519    }
4520    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4521        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4522        Ok(Box::new(cfg))
4523    }
4524}
4525
4526impl IsqModelLoader for GraniteMoeHybridLoader {
4527    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4528        Ok(vec![
4529            Regex::new(r"lm_head\.(weight|bias)$")?,
4530            // Attention
4531            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4532            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4533            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4534            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4535            // MLP (GraniteMLP uses shared_mlp.input_linear and shared_mlp.output_linear)
4536            Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
4537            Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
4538        ])
4539    }
4540    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4541        self.isq_layer_regexes(config)
4542    }
4543}
4544
4545impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
4546    fn mapped_max_act_size_elems(
4547        &self,
4548        config: &str,
4549        params: &AutoDeviceMapParams,
4550    ) -> Result<usize> {
4551        let AutoDeviceMapParams::Text {
4552            max_seq_len,
4553            max_batch_size,
4554        } = params
4555        else {
4556            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4557        };
4558
4559        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4560
4561        Ok(
4562            max_batch_size
4563                * cfg.num_attention_heads
4564                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4565        )
4566    }
4567    fn non_mapped_max_act_size_elems(
4568        &self,
4569        _config: &str,
4570        _params: &AutoDeviceMapParams,
4571    ) -> Result<usize> {
4572        Ok(0)
4573    }
4574
4575    fn non_mapped_size_in_bytes(
4576        &self,
4577        config: &str,
4578        dtype: DType,
4579        weight_pack_factor: usize,
4580        _matformer_config: Option<&MatformerSliceConfig>,
4581    ) -> Result<usize> {
4582        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4583
4584        let elems = {
4585            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4586            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4587            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4588                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4589            } else {
4590                0
4591            };
4592            let norm = cfg.hidden_size;
4593            embed_tokens + lm_head + norm
4594        };
4595        Ok(elems * dtype.size_in_bytes())
4596    }
4597
4598    fn layer_sizes_in_bytes(
4599        &self,
4600        config: &str,
4601        dtype: DType,
4602        weight_pack_factor: usize,
4603        _matformer_config: Option<&MatformerSliceConfig>,
4604    ) -> Result<Vec<usize>> {
4605        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4606
4607        let per_layer_elems = {
4608            let input_layernorm = cfg.hidden_size;
4609            let post_attention_layernorm = cfg.hidden_size;
4610
4611            let size_in = cfg.hidden_size;
4612            let size_q = cfg.head_dim() * cfg.num_attention_heads;
4613            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
4614            let q_proj = size_in * size_q / weight_pack_factor;
4615            let k_proj = size_in * size_kv / weight_pack_factor;
4616            let v_proj = size_in * size_kv / weight_pack_factor;
4617            let o_proj = size_q * size_in / weight_pack_factor;
4618
4619            let h_size = cfg.hidden_size;
4620            let shared_i_size = cfg.shared_intermediate_size();
4621            // GraniteMLP: input_linear (h_size -> shared_i_size * 2), output_linear (shared_i_size -> h_size)
4622            let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
4623            let output_linear = shared_i_size * h_size / weight_pack_factor;
4624
4625            input_layernorm
4626                + post_attention_layernorm
4627                + q_proj
4628                + k_proj
4629                + v_proj
4630                + o_proj
4631                + input_linear
4632                + output_linear
4633        };
4634        Ok(vec![
4635            per_layer_elems * dtype.size_in_bytes();
4636            cfg.num_hidden_layers
4637        ])
4638    }
4639
4640    fn num_layers(&self, config: &str) -> Result<usize> {
4641        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4642
4643        Ok(cfg.num_hidden_layers)
4644    }
4645    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4646        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4647
4648        let cfg = ModelConfigMetadata {
4649            max_seq_len: cfg.max_position_embeddings,
4650            num_layers: cfg.num_hidden_layers,
4651            hidden_size: cfg.hidden_size,
4652            num_kv_heads: cfg.num_key_value_heads(),
4653            num_attn_heads: cfg.num_attention_heads,
4654            sliding_window: None,
4655            k_head_dim: cfg.head_dim(),
4656            v_head_dim: cfg.head_dim(),
4657        };
4658
4659        Ok(Box::new(cfg))
4660    }
4661}
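
// A small illustrative sketch, not from the upstream tests, of the shared-MLP
// sizing above: `input_linear` fuses the gate and up projections, so it holds
// `hidden_size * shared_intermediate_size * 2` weights, while `output_linear`
// holds `shared_intermediate_size * hidden_size`. The dimensions below are
// hypothetical.
#[cfg(test)]
mod granite_shared_mlp_sizing_sketch {
    #[test]
    fn fused_input_linear_is_twice_output_linear() {
        // Hypothetical, unquantized dimensions.
        let hidden_size = 1024usize;
        let shared_intermediate_size = 4096usize;
        let weight_pack_factor = 1usize;

        let input_linear = hidden_size * shared_intermediate_size * 2 / weight_pack_factor;
        let output_linear = shared_intermediate_size * hidden_size / weight_pack_factor;

        assert_eq!(input_linear, 2 * output_linear);
        assert_eq!(input_linear + output_linear, 3 * hidden_size * shared_intermediate_size);
    }
}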
4662
4663// ======================== GPT-OSS loader
4664
4665/// [`NormalLoader`] for a GPT-OSS model.
4666///
4667/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
4668pub struct GptOssLoader;
4669
4670impl NormalModelLoader for GptOssLoader {
4671    fn load(
4672        &self,
4673        config: &str,
4674        vb: ShardedVarBuilder,
4675        normal_loading_metadata: NormalLoadingMetadata,
4676        attention_mechanism: AttentionImplementation,
4677    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4678        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4679
4680        Ok(Box::new(models::gpt_oss::Model::new(
4681            &cfg,
4682            vb,
4683            self.is_gptx(config)?,
4684            normal_loading_metadata,
4685            attention_mechanism,
4686        )?))
4687    }
4688    fn load_xlora(
4689        &self,
4690        _config: &str,
4691        _vb: ShardedVarBuilder,
4692        _lora_config: &[((String, String), LoraConfig)],
4693        _xlora_config: Option<XLoraConfig>,
4694        _xlora_ordering: Ordering,
4695        _normal_loading_metadata: NormalLoadingMetadata,
4696        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4697    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4698        anyhow::bail!("GPT-OSS does not support X-LoRA")
4699    }
4700    fn is_gptx(&self, _: &str) -> Result<bool> {
4701        Ok(true)
4702    }
4703    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4704        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4705        Ok(Box::new(cfg))
4706    }
4707    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4708        Ok(false)
4709    }
4710}
4711
4712impl IsqModelLoader for GptOssLoader {
4713    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4714        // Only lm_head and the attention projections are ISQ-able; the MoE experts are already MXFP4-quantized
4715        Ok(vec![
4716            Regex::new(r"lm_head\.(weight|bias)$")?,
4717            // Attention
4718            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4719            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4720            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4721            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4722        ])
4723    }
4724    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4725        self.isq_layer_regexes(config)
4726    }
4727}
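
// Illustrative sketch, not part of the crate's test suite; the tensor names are
// hypothetical. It shows the consequence of the comment above: only `lm_head`
// and the attention projections are ISQ candidates, so the MXFP4-packed expert
// tensors are never re-quantized.
#[cfg(test)]
mod gpt_oss_isq_regex_sketch {
    use super::*;

    #[test]
    fn expert_tensors_are_not_isq_candidates() {
        let regexes = GptOssLoader.isq_layer_regexes("").unwrap();
        let matches = |name: &str| regexes.iter().any(|r| r.is_match(name));

        // Attention projections are ISQ-able...
        assert!(matches("model.layers.0.self_attn.o_proj.weight"));
        // ...while the (hypothetically named) expert and router tensors are not.
        assert!(!matches("model.layers.0.mlp.experts.gate_up_proj"));
        assert!(!matches("model.layers.0.mlp.router.weight"));
    }
}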
4728
4729impl DeviceMappedModelLoader for GptOssLoader {
4730    fn mapped_max_act_size_elems(
4731        &self,
4732        config: &str,
4733        params: &AutoDeviceMapParams,
4734    ) -> Result<usize> {
4735        let AutoDeviceMapParams::Text {
4736            max_seq_len,
4737            max_batch_size,
4738        } = params
4739        else {
4740            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4741        };
4742
4743        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4744
4745        Ok(
4746            max_batch_size
4747                * cfg.num_attention_heads
4748                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4749        )
4750    }
4751    fn non_mapped_max_act_size_elems(
4752        &self,
4753        _config: &str,
4754        _params: &AutoDeviceMapParams,
4755    ) -> Result<usize> {
4756        Ok(0)
4757    }
4758
4759    fn non_mapped_size_in_bytes(
4760        &self,
4761        config: &str,
4762        dtype: DType,
4763        weight_pack_factor: usize,
4764        _matformer_config: Option<&MatformerSliceConfig>,
4765    ) -> Result<usize> {
4766        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4767
4768        let elems = {
4769            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4770            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4771                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4772            } else {
4773                0
4774            };
4775            let norm = cfg.hidden_size;
4776            embed_tokens + lm_head + norm
4777        };
4778        Ok(elems * dtype.size_in_bytes())
4779    }
4780
4781    fn layer_sizes_in_bytes(
4782        &self,
4783        config: &str,
4784        dtype: DType,
4785        weight_pack_factor: usize,
4786        _matformer_config: Option<&MatformerSliceConfig>,
4787    ) -> Result<Vec<usize>> {
4788        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4789
4790        let per_layer_elems = {
4791            let input_layernorm = cfg.hidden_size;
4792            let post_attention_layernorm = cfg.hidden_size;
4793
4794            let size_in = cfg.hidden_size;
4795            let head_dim = cfg.head_dim();
4796            let size_q = head_dim * cfg.num_attention_heads;
4797            let size_kv = head_dim * cfg.num_key_value_heads;
4798            let q_proj =
4799                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
4800            let k_proj =
4801                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4802            let v_proj =
4803                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4804            let o_proj =
4805                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
4806
4807            // MoE experts - MXFP4 quantized, so very compact
4808            // gate_up_proj: [num_experts, intermediate_size * 2, hidden_size/2] packed
4809            // down_proj: [num_experts, hidden_size, intermediate_size/2] packed
4810            // At 4 bits per weight, packing factor is 2
4811            let mxfp4_pack = 2;
4812            let gate_up_proj_size =
4813                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / mxfp4_pack;
4814            let down_proj_size =
4815                cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / mxfp4_pack;
4816            // Plus scales at 1 byte per 32 elements
4817            let gate_up_scales =
4818                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / 32;
4819            let down_scales = cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / 32;
4820            // Plus biases
4821            let gate_up_bias = cfg.num_local_experts * cfg.intermediate_size * 2;
4822            let down_bias = cfg.num_local_experts * cfg.hidden_size;
4823            // Router
4824            let router = cfg.hidden_size * cfg.num_local_experts;
4825            // Sinks per head
4826            let sinks = cfg.num_attention_heads;
4827
4828            input_layernorm
4829                + post_attention_layernorm
4830                + q_proj
4831                + k_proj
4832                + v_proj
4833                + o_proj
4834                + gate_up_proj_size
4835                + down_proj_size
4836                + gate_up_scales
4837                + down_scales
4838                + gate_up_bias
4839                + down_bias
4840                + router
4841                + sinks
4842        };
4843        Ok(vec![
4844            per_layer_elems * dtype.size_in_bytes();
4845            cfg.num_hidden_layers
4846        ])
4847    }
4848
4849    fn num_layers(&self, config: &str) -> Result<usize> {
4850        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4851
4852        Ok(cfg.num_hidden_layers)
4853    }
4854    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4855        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4856
4857        let head_dim = cfg.head_dim();
4858        let cfg = ModelConfigMetadata {
4859            max_seq_len: cfg.max_position_embeddings,
4860            num_layers: cfg.num_hidden_layers,
4861            hidden_size: cfg.hidden_size,
4862            num_kv_heads: cfg.num_key_value_heads,
4863            num_attn_heads: cfg.num_attention_heads,
4864            sliding_window: cfg.sliding_window,
4865            k_head_dim: head_dim,
4866            v_head_dim: head_dim,
4867        };
4868
4869        Ok(Box::new(cfg))
4870    }
4871}
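
// A back-of-envelope sketch, not from the upstream tests, of the MXFP4 expert
// accounting above: two 4-bit weights pack into one byte (pack factor 2) and
// each group of 32 weights carries a one-byte scale. It counts bytes directly
// rather than going through the estimator's element count, and every dimension
// below is hypothetical, chosen only to keep the arithmetic simple.
#[cfg(test)]
mod gpt_oss_mxfp4_sizing_sketch {
    #[test]
    fn packed_experts_are_a_fraction_of_unpacked_size() {
        // Hypothetical per-layer MoE dimensions.
        let num_local_experts = 8usize;
        let hidden_size = 1024usize;
        let intermediate_size = 2048usize;

        let mxfp4_pack = 2usize;
        let gate_up = num_local_experts * intermediate_size * 2 * hidden_size;
        let down = num_local_experts * hidden_size * intermediate_size;

        // Packed nibbles plus one scale byte per 32 weights, as in the code above.
        let packed_bytes = gate_up / mxfp4_pack + down / mxfp4_pack + gate_up / 32 + down / 32;
        let f32_bytes = (gate_up + down) * 4;

        assert_eq!(packed_bytes, 26_738_688);
        // 4-bit weights plus scales are more than 7x smaller than an fp32 copy.
        assert!(packed_bytes * 7 < f32_bytes);
    }
}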