mistralrs_core/pipeline/loaders/normal_loaders.rs

use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    lora::{LoraConfig, Ordering},
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{
        isq::IsqModelLoader,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::varbuilder_utils::DeviceForLoadTensor,
    xlora_models::NonGranularState,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use mistralrs_quant::log::once_log_info;

use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use crate::{
    models,
    xlora_models::{self, XLoraConfig},
};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn is_xlora(&self) -> bool;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
}

/// Metadata for loading a model with ISQ or device mapping.
pub struct NormalLoadingMetadata {
    // Device mapping metadata which can be used to construct a concrete device mapper
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Flag to check if loading in ISQ
    pub loading_isq: bool,
    // Device mapping target device (the one that is not the cpu)
    pub real_device: Device,
    // MultiProgress support for parallelized loading
    pub multi_progress: Arc<MultiProgress>,
    // Optional Matryoshka Transformer slicing configuration
    pub matformer_slicing_config: Option<MatformerSliceConfig>,
}

pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    #[allow(clippy::too_many_arguments)]
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}
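
// A minimal test-only sketch of the tensor-name routing used by the default
// `get_device_for_tensor` above: names containing `.layers.<n>.` are mapped to
// that layer index, anything else falls back to the base device. The module
// name and sample tensor names below are illustrative, not from the original file.
#[cfg(test)]
mod device_for_tensor_name_sketch {
    use regex::Regex;

    #[test]
    fn layer_index_is_extracted_from_tensor_names() {
        let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
        let layer = |name: &str| {
            re.captures(name)
                .and_then(|c| c.get(1))
                .and_then(|m| m.as_str().parse::<usize>().ok())
        };
        assert_eq!(layer("model.layers.17.self_attn.q_proj.weight"), Some(17));
        // Non per-layer tensors (e.g. the LM head) get no index and stay on the base device.
        assert_eq!(layer("lm_head.weight"), None);
    }
}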

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the normal model as.
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    #[serde(rename = "qwen3")]
    Qwen3,
    #[serde(rename = "glm4")]
    GLM4,
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
    #[serde(rename = "smollm3")]
    SmolLm3,
    #[serde(rename = "granitemoehybrid")]
    GraniteMoeHybrid,
    #[serde(rename = "gpt_oss")]
    GptOss,
}

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
impl NormalLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "MistralForCausalLM" => Ok(Self::Mistral),
            "MixtralForCausalLM" => Ok(Self::Mixtral),
            "GemmaForCausalLM" => Ok(Self::Gemma),
            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
            "PhiForCausalLM" => Ok(Self::Phi2),
            "Phi3ForCausalLM" => Ok(Self::Phi3),
            "LlamaForCausalLM" => Ok(Self::Llama),
            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
            "Glm4ForCausalLM" => Ok(Self::GLM4),
            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
            "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
            "GptOssForCausalLM" => Ok(Self::GptOss),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for NormalLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "mistral" => Ok(Self::Mistral),
            "gemma" => Ok(Self::Gemma),
            "mixtral" => Ok(Self::Mixtral),
            "llama" => Ok(Self::Llama),
            "phi2" => Ok(Self::Phi2),
            "phi3" => Ok(Self::Phi3),
            "qwen2" => Ok(Self::Qwen2),
            "gemma2" => Ok(Self::Gemma2),
            "starcoder2" => Ok(Self::Starcoder2),
            "phi3.5moe" => Ok(Self::Phi3_5MoE),
            "deepseekv2" => Ok(Self::DeepSeekV2),
            "deepseekv3" => Ok(Self::DeepSeekV3),
            "qwen3" => Ok(Self::Qwen3),
            "glm4" => Ok(Self::GLM4),
            "qwen3moe" => Ok(Self::Qwen3Moe),
            "smollm3" => Ok(Self::SmolLm3),
            "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
            "gpt_oss" => Ok(Self::GptOss),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`, `granitemoehybrid`, `gpt_oss`.")),
        }
    }
}

impl Display for NormalLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gemma => write!(f, "gemma"),
            Self::Gemma2 => write!(f, "gemma2"),
            Self::Llama => write!(f, "llama"),
            Self::Mistral => write!(f, "mistral"),
            Self::Mixtral => write!(f, "mixtral"),
            Self::Phi2 => write!(f, "phi2"),
            Self::Phi3 => write!(f, "phi3"),
            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
            Self::Qwen2 => write!(f, "qwen2"),
            Self::Starcoder2 => write!(f, "starcoder2"),
            Self::DeepSeekV2 => write!(f, "deepseekv2"),
            Self::DeepSeekV3 => write!(f, "deepseekv3"),
            Self::Qwen3 => write!(f, "qwen3"),
            Self::GLM4 => write!(f, "glm4"),
            Self::Qwen3Moe => write!(f, "qwen3moe"),
            Self::SmolLm3 => write!(f, "smollm3"),
            Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
            Self::GptOss => write!(f, "gpt_oss"),
        }
    }
}
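
// A minimal test-only sketch of the `FromStr`/`Display` round trip for
// `NormalLoaderType`; the architecture names are the serde/`from_str` names
// listed above, while the module and test names are illustrative.
#[cfg(test)]
mod normal_loader_type_roundtrip_sketch {
    use super::NormalLoaderType;
    use std::str::FromStr;

    #[test]
    fn known_architectures_round_trip() {
        for name in ["mistral", "llama", "phi3.5moe", "qwen3moe", "gpt_oss"] {
            let tp = NormalLoaderType::from_str(name).expect("known architecture");
            assert_eq!(tp.to_string(), name);
        }
    }

    #[test]
    fn unknown_architectures_are_rejected() {
        assert!(NormalLoaderType::from_str("not-an-arch").is_err());
    }
}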

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
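
// A minimal sketch of how `bias_if!` is used in the size computations below:
// it contributes the bias element count only when the corresponding config
// flag (e.g. `attention_bias`) is set. The values here are illustrative.
#[cfg(test)]
mod bias_if_sketch {
    #[test]
    fn bias_elements_are_counted_only_when_enabled() {
        let size_q = 4096usize; // hypothetical projection output size
        assert_eq!(bias_if!(true, size_q), 4096);
        assert_eq!(bias_if!(false, size_q), 0);
    }
}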

/// Load a model based on the Hugging Face Transformers `-CausalLM` model class.
pub struct AutoNormalLoader;

#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    architectures: Vec<String>,
}

impl AutoNormalLoader {
    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one name in the `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = NormalLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
            NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
            NormalLoaderType::GptOss => Ok(Box::new(GptOssLoader)),
        }
    }
}
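
// A minimal test-only sketch of the config handling in `get_loader` above: the
// automatic loader only inspects the `architectures` field of a Hugging Face
// `config.json` and maps the class name via `from_causal_lm_name`. The JSON
// snippet is illustrative.
#[cfg(test)]
mod auto_loader_architecture_sketch {
    use super::{AutoNormalLoaderConfig, NormalLoaderType};

    #[test]
    fn architectures_field_selects_the_loader_type() {
        let cfg: AutoNormalLoaderConfig =
            serde_json::from_str(r#"{"architectures": ["MistralForCausalLM"]}"#).unwrap();
        assert_eq!(cfg.architectures.len(), 1);
        let tp = NormalLoaderType::from_causal_lm_name(&cfg.architectures[0]).unwrap();
        assert_eq!(tp, NormalLoaderType::Mistral);
    }
}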

impl NormalModelLoader for AutoNormalLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load_xlora(
            config,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            normal_loading_metadata,
            preload_adapters,
        )
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }
    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.supports_paged_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoNormalLoader {
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
    }
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for AutoNormalLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

// ======================== Mistral loader

pub struct MistralLoader;

impl NormalModelLoader for MistralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(models::mistral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(xlora_models::XLoraMistral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MistralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
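
// A minimal test-only sketch of what the Mistral ISQ regexes above select:
// per-layer attention/MLP projections and the LM head are in-situ quantizable,
// while norm weights are not. The tensor names below are illustrative.
#[cfg(test)]
mod mistral_isq_regex_sketch {
    use super::*;

    #[test]
    fn projections_match_and_norms_do_not() {
        let regexes = MistralLoader.isq_layer_regexes("{}").unwrap();
        let is_isq = |name: &str| regexes.iter().any(|re| re.is_match(name));
        assert!(is_isq("model.layers.0.self_attn.q_proj.weight"));
        assert!(is_isq("model.layers.3.mlp.down_proj.weight"));
        assert!(is_isq("lm_head.weight"));
        assert!(!is_isq("model.layers.0.input_layernorm.weight"));
    }
}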

impl DeviceMappedModelLoader for MistralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}
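
// A minimal test-only sketch of the per-layer element count computed in
// `layer_sizes_in_bytes` above, using hypothetical Mistral-7B-like dimensions
// (hidden 4096, intermediate 14336, 32 attention heads, 8 KV heads) with no
// weight packing; the real values come from the parsed `config.json`.
#[cfg(test)]
mod mistral_layer_size_sketch {
    #[test]
    fn per_layer_element_arithmetic() {
        let (hidden, intermediate, heads, kv_heads) = (4096usize, 14336usize, 32usize, 8usize);
        let head_dim = hidden / heads;
        let (size_q, size_kv) = (head_dim * heads, head_dim * kv_heads);
        // q/k/v/o projections, unpacked (weight_pack_factor == 1)
        let attn = hidden * size_q + 2 * hidden * size_kv + size_q * hidden;
        // gate/up/down projections
        let mlp = 3 * hidden * intermediate;
        // input + post-attention norm weights
        let norms = 2 * hidden;
        assert_eq!(attn + mlp + norms, 218_112_000);
        // At 2 bytes per element (e.g. bf16) this is roughly 436 MB of weights per layer.
        assert_eq!((attn + mlp + norms) * 2, 436_224_000);
    }
}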

// ======================== Gemma loader

/// [`NormalLoader`] for a Gemma model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct GemmaLoader;

impl NormalModelLoader for GemmaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::gemma::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraGemma::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for GemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Llama loader

/// [`NormalLoader`] for a Llama model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct LlamaLoader;

impl NormalModelLoader for LlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::llama::Llama::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraLlama::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for LlamaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for LlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Mixtral loader

pub struct MixtralLoader;

impl NormalModelLoader for MixtralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::mixtral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraMixtral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MixtralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // Experts
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
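
// A minimal test-only sketch of the Mixtral ISQ regexes above: per-expert
// `w1`/`w2`/`w3` weights inside `block_sparse_moe` are matched alongside the
// attention projections. The tensor names are illustrative.
#[cfg(test)]
mod mixtral_isq_regex_sketch {
    use super::*;

    #[test]
    fn expert_weights_are_selected() {
        let regexes = MixtralLoader.isq_layer_regexes("{}").unwrap();
        let is_isq = |name: &str| regexes.iter().any(|re| re.is_match(name));
        assert!(is_isq("model.layers.5.block_sparse_moe.experts.7.w1.weight"));
        assert!(is_isq("model.layers.5.block_sparse_moe.gate.weight"));
        assert!(!is_isq("model.layers.5.post_attention_layernorm.weight"));
    }
}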

impl DeviceMappedModelLoader for MixtralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let moe_block = {
                let gate = cfg.hidden_size * cfg.num_local_experts;
                // Assume the expert weights share the same weight pack factor
                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                gate + cfg.num_local_experts * w1
                    + cfg.num_local_experts * w2
                    + cfg.num_local_experts * w3
            };

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + moe_block
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Phi2 loader

/// [`NormalLoader`] for a Phi 2 model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Phi2Loader;

impl NormalModelLoader for Phi2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::phi2::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraPhi2::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Phi2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor + size_in;
            let (q_norm, k_norm) = if cfg.qk_layernorm {
                (cfg.head_dim(), cfg.head_dim())
            } else {
                (0, 0)
            };

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let fc1 = h_size * i_size / weight_pack_factor;
            let fc2 = h_size * i_size / weight_pack_factor;

            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}
1403
1404// ======================== Phi3 loader
1405
1406/// [`NormalLoader`] for a Phi 3 model.
1407///
1408/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1409pub struct Phi3Loader;
1410
1411impl NormalModelLoader for Phi3Loader {
1412    fn load(
1413        &self,
1414        config: &str,
1415        vb: ShardedVarBuilder,
1416        normal_loading_metadata: NormalLoadingMetadata,
1417        attention_mechanism: AttentionImplementation,
1418    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1419        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1420
1421        Ok(Box::new(models::phi3::Model::new(
1422            &cfg,
1423            vb,
1424            self.is_gptx(config)?,
1425            normal_loading_metadata,
1426            attention_mechanism,
1427        )?))
1428    }
1429    fn load_xlora(
1430        &self,
1431        config: &str,
1432        vb: ShardedVarBuilder,
1433        lora_config: &[((String, String), LoraConfig)],
1434        xlora_config: Option<XLoraConfig>,
1435        xlora_ordering: Ordering,
1436        normal_loading_metadata: NormalLoadingMetadata,
1437        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1438    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1439        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1440
1441        Ok(Box::new(xlora_models::XLoraPhi3::new(
1442            &cfg,
1443            vb,
1444            lora_config,
1445            xlora_config,
1446            xlora_ordering,
1447            self.is_gptx(config)?,
1448            normal_loading_metadata,
1449            preload_adapters,
1450        )?))
1451    }
1452    fn is_gptx(&self, _: &str) -> Result<bool> {
1453        Ok(true)
1454    }
1455    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1456        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1457
1458        Ok(Box::new(cfg))
1459    }
1460}
1461
1462impl IsqModelLoader for Phi3Loader {
1463    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1464        Ok(vec![
1465            Regex::new(r"lm_head\.(weight|bias)$")?,
1466            // Attention
1467            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1468            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1469            // MLP
1470            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1471            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1472            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1473        ])
1474    }
1475    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1476        self.isq_layer_regexes(config)
1477    }
1478}
1479
1480impl DeviceMappedModelLoader for Phi3Loader {
1481    fn mapped_max_act_size_elems(
1482        &self,
1483        config: &str,
1484        params: &AutoDeviceMapParams,
1485    ) -> Result<usize> {
1486        let AutoDeviceMapParams::Text {
1487            max_seq_len,
1488            max_batch_size,
1489        } = params
1490        else {
1491            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1492        };
1493
1494        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1495
1496        Ok(
1497            max_batch_size
1498                * cfg.num_attention_heads
1499                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1500        )
1501    }
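    // Sketch of the estimate above (illustrative numbers only): the dominant mapped activation is
    // assumed to be a single attention-score tensor of shape [batch, heads, q_len, kv_len] with
    // q_len = kv_len capped at the attention chunk size. E.g. if max_batch_size = 1,
    // num_attention_heads = 32 and the chunk were 1024 tokens: 1 * 32 * 1024^2 = 33_554_432 elements.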
1502    fn non_mapped_max_act_size_elems(
1503        &self,
1504        _config: &str,
1505        _params: &AutoDeviceMapParams,
1506    ) -> Result<usize> {
1507        Ok(0)
1508    }
1509
1510    fn non_mapped_size_in_bytes(
1511        &self,
1512        config: &str,
1513        dtype: DType,
1514        weight_pack_factor: usize,
1515        _matformer_config: Option<&MatformerSliceConfig>,
1516    ) -> Result<usize> {
1517        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1518
1519        let elems = {
1520            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1521            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1522            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1523                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1524            } else {
1525                0
1526            };
1527            let norm = cfg.hidden_size;
1528            embed_tokens + lm_head + norm
1529        };
1530        Ok(elems * dtype.size_in_bytes())
1531    }
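    // Illustrative example of the tied-embedding rule above: with tie_word_embeddings = true and
    // weight_pack_factor == 1, lm_head reuses the embedding matrix and contributes 0 elements;
    // otherwise it is counted again. For a hypothetical hidden_size = 3072 and vocab_size = 32064,
    // that second copy would be 3072 * 32064 = 98_500_608 elements.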
1532
1533    fn layer_sizes_in_bytes(
1534        &self,
1535        config: &str,
1536        dtype: DType,
1537        weight_pack_factor: usize,
1538        _matformer_config: Option<&MatformerSliceConfig>,
1539    ) -> Result<Vec<usize>> {
1540        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1541
1542        let per_layer_elems = {
1543            let input_layernorm = cfg.hidden_size;
1544            let post_attention_layernorm = cfg.hidden_size;
1545
1546            let size_in = cfg.hidden_size;
1547            let head_dim = cfg.head_dim();
1548            let op_size =
1549                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1550            let qkv_proj = size_in * op_size / weight_pack_factor;
1551            let o_proj =
1552                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1553
1554            let h_size = cfg.hidden_size;
1555            let i_size = cfg.intermediate_size;
1556            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1557            let down_proj = h_size * i_size / weight_pack_factor;
1558
1559            input_layernorm
1560                + post_attention_layernorm
1561                + qkv_proj
1562                + o_proj
1563                + gate_up_proj
1564                + down_proj
1565        };
1566        Ok(vec![
1567            per_layer_elems * dtype.size_in_bytes();
1568            cfg.num_hidden_layers
1569        ])
1570    }
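    // Worked example (hypothetical Phi-3-mini-like values, weight_pack_factor = 1: hidden_size = 3072,
    // 32 attention heads, 32 KV heads, head_dim = 96, intermediate_size = 8192):
    //   qkv_proj     = 3072 * (32*96 + 2*32*96) = 28_311_552
    //   o_proj       = 3072 * 3072 + 3072       =  9_440_256
    //   gate_up_proj = 3072 * (2 * 8192)        = 50_331_648
    //   down_proj    = 3072 * 8192              = 25_165_824
    //   norms        = 2 * 3072                 =      6_144
    // giving 113_255_424 elements, i.e. roughly 216 MiB per layer at 2 bytes per element.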
1571
1572    fn num_layers(&self, config: &str) -> Result<usize> {
1573        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1574
1575        Ok(cfg.num_hidden_layers)
1576    }
1577
1578    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1579        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1580
1581        let cfg = ModelConfigMetadata {
1582            max_seq_len: cfg.max_position_embeddings,
1583            num_layers: cfg.num_hidden_layers,
1584            hidden_size: cfg.hidden_size,
1585            num_kv_heads: cfg.num_key_value_heads,
1586            num_attn_heads: cfg.num_attention_heads,
1587            sliding_window: cfg.sliding_window,
1588            k_head_dim: cfg.head_dim(),
1589            v_head_dim: cfg.head_dim(),
1590        };
1591
1592        Ok(Box::new(cfg))
1593    }
1594}
1595
1596// ======================== Qwen2 loader
1597
1598/// [`NormalLoader`] for a Qwen 2 model.
1599///
1600/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1601pub struct Qwen2Loader;
1602
1603impl NormalModelLoader for Qwen2Loader {
1604    fn load(
1605        &self,
1606        config: &str,
1607        vb: ShardedVarBuilder,
1608        normal_loading_metadata: NormalLoadingMetadata,
1609        attention_mechanism: AttentionImplementation,
1610    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1611        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1612
1613        Ok(Box::new(models::qwen2::Model::new(
1614            &cfg,
1615            vb,
1616            self.is_gptx(config)?,
1617            normal_loading_metadata,
1618            attention_mechanism,
1619        )?))
1620    }
1621    fn load_xlora(
1622        &self,
1623        _config: &str,
1624        _vb: ShardedVarBuilder,
1625        _lora_config: &[((String, String), LoraConfig)],
1626        _xlora_config: Option<XLoraConfig>,
1627        _xlora_ordering: Ordering,
1628        _normal_loading_metadata: NormalLoadingMetadata,
1629        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1630    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1631        todo!()
1632    }
1633    fn is_gptx(&self, _: &str) -> Result<bool> {
1634        Ok(true)
1635    }
1636    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1637        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1638
1639        Ok(Box::new(cfg))
1640    }
1641}
1642
1643impl IsqModelLoader for Qwen2Loader {
1644    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1645        Ok(vec![
1646            Regex::new(r"lm_head\.(weight|bias)$")?,
1647            // Attention
1648            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1649            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1650            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1651            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1652            // MLP
1653            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1654            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1655            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1656        ])
1657    }
1658    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1659        self.isq_layer_regexes(config)
1660    }
1661}
1662
1663impl DeviceMappedModelLoader for Qwen2Loader {
1664    fn mapped_max_act_size_elems(
1665        &self,
1666        config: &str,
1667        params: &AutoDeviceMapParams,
1668    ) -> Result<usize> {
1669        let AutoDeviceMapParams::Text {
1670            max_seq_len,
1671            max_batch_size,
1672        } = params
1673        else {
1674            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1675        };
1676
1677        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1678
1679        Ok(
1680            max_batch_size
1681                * cfg.num_attention_heads
1682                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1683        )
1684    }
1685    fn non_mapped_max_act_size_elems(
1686        &self,
1687        _config: &str,
1688        _params: &AutoDeviceMapParams,
1689    ) -> Result<usize> {
1690        Ok(0)
1691    }
1692
1693    fn non_mapped_size_in_bytes(
1694        &self,
1695        config: &str,
1696        dtype: DType,
1697        weight_pack_factor: usize,
1698        _matformer_config: Option<&MatformerSliceConfig>,
1699    ) -> Result<usize> {
1700        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1701
1702        let elems = {
1703            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1704            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1705            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1706                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1707            } else {
1708                0
1709            };
1710            let norm = cfg.hidden_size;
1711            embed_tokens + lm_head + norm
1712        };
1713        Ok(elems * dtype.size_in_bytes())
1714    }
1715
1716    fn layer_sizes_in_bytes(
1717        &self,
1718        config: &str,
1719        dtype: DType,
1720        weight_pack_factor: usize,
1721        _matformer_config: Option<&MatformerSliceConfig>,
1722    ) -> Result<Vec<usize>> {
1723        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1724
1725        let per_layer_elems = {
1726            let input_layernorm = cfg.hidden_size;
1727            let post_attention_layernorm = cfg.hidden_size;
1728
1729            let size_in = cfg.hidden_size;
1730            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1731            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1732            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1733            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1734            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1735            let o_proj = size_q * size_in / weight_pack_factor;
1736
1737            let h_size = cfg.hidden_size;
1738            let i_size = cfg.intermediate_size;
1739            let gate_proj = h_size * i_size / weight_pack_factor;
1740            let up_proj = h_size * i_size / weight_pack_factor;
1741            let down_proj = i_size * h_size / weight_pack_factor;
1742
1743            input_layernorm
1744                + post_attention_layernorm
1745                + q_proj
1746                + k_proj
1747                + v_proj
1748                + o_proj
1749                + gate_proj
1750                + up_proj
1751                + down_proj
1752        };
1753        Ok(vec![
1754            per_layer_elems * dtype.size_in_bytes();
1755            cfg.num_hidden_layers
1756        ])
1757    }
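    // The `+ size_q` / `+ size_kv` terms above account for the biases Qwen2 places on its q/k/v
    // projections; o_proj carries no bias, so no bias term is added for it.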
1758
1759    fn num_layers(&self, config: &str) -> Result<usize> {
1760        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1761
1762        Ok(cfg.num_hidden_layers)
1763    }
1764
1765    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1766        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1767
1768        let cfg = ModelConfigMetadata {
1769            max_seq_len: cfg.max_position_embeddings,
1770            num_layers: cfg.num_hidden_layers,
1771            hidden_size: cfg.hidden_size,
1772            num_kv_heads: cfg.num_key_value_heads,
1773            num_attn_heads: cfg.num_attention_heads,
1774            sliding_window: cfg.sliding_window,
1775            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1776            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1777        };
1778
1779        Ok(Box::new(cfg))
1780    }
1781}
1782
1783// ======================== Gemma2 loader
1784
1785/// [`NormalLoader`] for a Gemma2 model.
1786///
1787/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1788pub struct Gemma2Loader;
1789
1790impl NormalModelLoader for Gemma2Loader {
1791    fn load(
1792        &self,
1793        config: &str,
1794        vb: ShardedVarBuilder,
1795        normal_loading_metadata: NormalLoadingMetadata,
1796        attention_mechanism: AttentionImplementation,
1797    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1798        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1799
1800        Ok(Box::new(models::gemma2::Model::new(
1801            &cfg,
1802            vb,
1803            self.is_gptx(config)?,
1804            normal_loading_metadata,
1805            attention_mechanism,
1806        )?))
1807    }
1808    fn load_xlora(
1809        &self,
1810        config: &str,
1811        vb: ShardedVarBuilder,
1812        lora_config: &[((String, String), LoraConfig)],
1813        xlora_config: Option<XLoraConfig>,
1814        xlora_ordering: Ordering,
1815        normal_loading_metadata: NormalLoadingMetadata,
1816        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1817    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1818        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1819
1820        Ok(Box::new(xlora_models::XLoraGemma2::new(
1821            &cfg,
1822            vb,
1823            lora_config,
1824            xlora_config,
1825            xlora_ordering,
1826            self.is_gptx(config)?,
1827            normal_loading_metadata,
1828            preload_adapters,
1829        )?))
1830    }
1831    fn is_gptx(&self, _: &str) -> Result<bool> {
1832        Ok(true)
1833    }
1834    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1835        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1836
1837        Ok(Box::new(cfg))
1838    }
1839}
1840
1841impl IsqModelLoader for Gemma2Loader {
1842    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1843        Ok(vec![
1844            Regex::new(r"lm_head\.(weight|bias)$")?,
1845            // Attention
1846            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1847            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1848            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1849            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1850            // MLP
1851            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1852            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1853            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1854        ])
1855    }
1856    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1857        self.isq_layer_regexes(config)
1858    }
1859}
1860
1861impl DeviceMappedModelLoader for Gemma2Loader {
1862    fn mapped_max_act_size_elems(
1863        &self,
1864        config: &str,
1865        params: &AutoDeviceMapParams,
1866    ) -> Result<usize> {
1867        let AutoDeviceMapParams::Text {
1868            max_seq_len,
1869            max_batch_size,
1870        } = params
1871        else {
1872            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1873        };
1874
1875        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1876
1877        Ok(
1878            max_batch_size
1879                * cfg.num_attention_heads
1880                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1881        )
1882    }
1883    fn non_mapped_max_act_size_elems(
1884        &self,
1885        _config: &str,
1886        _params: &AutoDeviceMapParams,
1887    ) -> Result<usize> {
1888        Ok(0)
1889    }
1890
1891    fn non_mapped_size_in_bytes(
1892        &self,
1893        config: &str,
1894        dtype: DType,
1895        weight_pack_factor: usize,
1896        _matformer_config: Option<&MatformerSliceConfig>,
1897    ) -> Result<usize> {
1898        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1899
1900        let elems = {
1901            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1902            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1903            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1904                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1905            } else {
1906                0
1907            };
1908            let norm = cfg.hidden_size;
1909            embed_tokens + lm_head + norm
1910        };
1911        Ok(elems * dtype.size_in_bytes())
1912    }
1913
1914    fn layer_sizes_in_bytes(
1915        &self,
1916        config: &str,
1917        dtype: DType,
1918        weight_pack_factor: usize,
1919        _matformer_config: Option<&MatformerSliceConfig>,
1920    ) -> Result<Vec<usize>> {
1921        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1922
1923        let per_layer_elems = {
1924            let input_layernorm = cfg.hidden_size;
1925            let post_attention_layernorm = cfg.hidden_size;
1926
1927            let size_in = cfg.hidden_size;
1928            let size_q = cfg.head_dim * cfg.num_attention_heads;
1929            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1930            let q_proj =
1931                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1932            let k_proj =
1933                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1934            let v_proj =
1935                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1936            let o_proj =
1937                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1938
1939            let h_size = cfg.hidden_size;
1940            let i_size = cfg.intermediate_size;
1941            let gate_proj = h_size * i_size / weight_pack_factor;
1942            let up_proj = h_size * i_size / weight_pack_factor;
1943            let down_proj = i_size * h_size / weight_pack_factor;
1944
1945            input_layernorm
1946                + post_attention_layernorm
1947                + q_proj
1948                + k_proj
1949                + v_proj
1950                + o_proj
1951                + gate_proj
1952                + up_proj
1953                + down_proj
1954        };
1955        Ok(vec![
1956            per_layer_elems * dtype.size_in_bytes();
1957            cfg.num_hidden_layers
1958        ])
1959    }
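    // Note: Gemma 2 reads `head_dim` directly from the config rather than deriving it, so `size_q`
    // need not equal `hidden_size`. E.g. with roughly Gemma-2-9B-like values (hidden_size = 3584,
    // 16 heads, head_dim = 256; illustrative only), size_q = 16 * 256 = 4096.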
1960
1961    fn num_layers(&self, config: &str) -> Result<usize> {
1962        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1963
1964        Ok(cfg.num_hidden_layers)
1965    }
1966    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1967        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1968
1969        let cfg = ModelConfigMetadata {
1970            max_seq_len: cfg.max_position_embeddings,
1971            num_layers: cfg.num_hidden_layers,
1972            hidden_size: cfg.hidden_size,
1973            num_kv_heads: cfg.num_key_value_heads,
1974            num_attn_heads: cfg.num_attention_heads,
1975            sliding_window: None, // None to be more forgiving; some Gemma 2 configs do not set sliding_window
1976            k_head_dim: cfg.head_dim, // Gemma 2 configures head_dim explicitly; it may differ from hidden_size / num_attention_heads
1977            v_head_dim: cfg.head_dim,
1978        };
1979
1980        Ok(Box::new(cfg))
1981    }
1982}
1983
1984// ======================== Starcoder2 loader
1985
1986/// [`NormalLoader`] for a Starcoder2 model.
1987///
1988/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1989pub struct Starcoder2Loader;
1990
1991impl NormalModelLoader for Starcoder2Loader {
1992    fn load(
1993        &self,
1994        config: &str,
1995        vb: ShardedVarBuilder,
1996        normal_loading_metadata: NormalLoadingMetadata,
1997        attention_mechanism: AttentionImplementation,
1998    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1999        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2000
2001        Ok(Box::new(models::starcoder2::Model::new(
2002            &cfg,
2003            vb,
2004            self.is_gptx(config)?,
2005            normal_loading_metadata,
2006            attention_mechanism,
2007        )?))
2008    }
2009    fn load_xlora(
2010        &self,
2011        config: &str,
2012        vb: ShardedVarBuilder,
2013        lora_config: &[((String, String), LoraConfig)],
2014        xlora_config: Option<XLoraConfig>,
2015        xlora_ordering: Ordering,
2016        normal_loading_metadata: NormalLoadingMetadata,
2017        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2018    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2019        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2020
2021        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2022            &cfg,
2023            vb,
2024            lora_config,
2025            xlora_config,
2026            xlora_ordering,
2027            self.is_gptx(config)?,
2028            normal_loading_metadata,
2029            preload_adapters,
2030        )?))
2031    }
2032    fn is_gptx(&self, _: &str) -> Result<bool> {
2033        Ok(true)
2034    }
2035    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2036        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2037
2038        Ok(Box::new(cfg))
2039    }
2040}
2041
2042impl IsqModelLoader for Starcoder2Loader {
2043    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2044        Ok(vec![
2045            Regex::new(r"lm_head\.(weight|bias)$")?,
2046            // Attention
2047            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2048            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2049            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2050            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2051            // MLP
2052            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2053            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2054        ])
2055    }
2056    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2057        self.isq_layer_regexes(config)
2058    }
2059}
2060
2061impl DeviceMappedModelLoader for Starcoder2Loader {
2062    fn mapped_max_act_size_elems(
2063        &self,
2064        config: &str,
2065        params: &AutoDeviceMapParams,
2066    ) -> Result<usize> {
2067        let AutoDeviceMapParams::Text {
2068            max_seq_len,
2069            max_batch_size,
2070        } = params
2071        else {
2072            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2073        };
2074
2075        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2076
2077        Ok(
2078            max_batch_size
2079                * cfg.num_attention_heads
2080                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2081        )
2082    }
2083    fn non_mapped_max_act_size_elems(
2084        &self,
2085        _config: &str,
2086        _params: &AutoDeviceMapParams,
2087    ) -> Result<usize> {
2088        Ok(0)
2089    }
2090
2091    fn non_mapped_size_in_bytes(
2092        &self,
2093        config: &str,
2094        dtype: DType,
2095        weight_pack_factor: usize,
2096        _matformer_config: Option<&MatformerSliceConfig>,
2097    ) -> Result<usize> {
2098        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2099
2100        let elems = {
2101            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2102            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2103            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2104                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2105            } else {
2106                0
2107            };
2108            let norm = cfg.hidden_size + cfg.hidden_size; // norm weight + bias
2109            embed_tokens + lm_head + norm
2110        };
2111        Ok(elems * dtype.size_in_bytes())
2112    }
2113
2114    fn layer_sizes_in_bytes(
2115        &self,
2116        config: &str,
2117        dtype: DType,
2118        weight_pack_factor: usize,
2119        _matformer_config: Option<&MatformerSliceConfig>,
2120    ) -> Result<Vec<usize>> {
2121        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2122
2123        let per_layer_elems = {
2124            let input_layernorm = cfg.hidden_size + cfg.hidden_size; // weight + bias
2125            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size; // weight + bias
2126
2127            let size_in = cfg.hidden_size;
2128            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2129            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2130            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2131            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2132            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2133            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2134
2135            let h_size = cfg.hidden_size;
2136            let i_size = cfg.intermediate_size;
2137            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2138            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2139
2140            input_layernorm
2141                + post_attention_layernorm
2142                + q_proj
2143                + k_proj
2144                + v_proj
2145                + o_proj
2146                + fc1
2147                + fc2
2148        };
2149        Ok(vec![
2150            per_layer_elems * dtype.size_in_bytes();
2151            cfg.num_hidden_layers
2152        ])
2153    }
2154
2155    fn num_layers(&self, config: &str) -> Result<usize> {
2156        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2157
2158        Ok(cfg.num_hidden_layers)
2159    }
2160
2161    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2162        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2163
2164        let cfg = ModelConfigMetadata {
2165            max_seq_len: cfg.max_position_embeddings,
2166            num_layers: cfg.num_hidden_layers,
2167            hidden_size: cfg.hidden_size,
2168            num_kv_heads: cfg.num_key_value_heads,
2169            num_attn_heads: cfg.num_attention_heads,
2170            sliding_window: cfg.sliding_window,
2171            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2172            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2173        };
2174
2175        Ok(Box::new(cfg))
2176    }
2177}
2178
2179// ======================== Phi 3.5 MoE loader
2180
2181/// [`NormalLoader`] for a Phi 3.5 MoE model.
2182///
2183/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2184pub struct Phi3_5MoELoader;
2185
2186impl NormalModelLoader for Phi3_5MoELoader {
2187    fn load(
2188        &self,
2189        config: &str,
2190        vb: ShardedVarBuilder,
2191        normal_loading_metadata: NormalLoadingMetadata,
2192        attention_mechanism: AttentionImplementation,
2193    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2194        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2195
2196        Ok(Box::new(models::phi3_5_moe::Model::new(
2197            &cfg,
2198            vb,
2199            self.is_gptx(config)?,
2200            normal_loading_metadata,
2201            attention_mechanism,
2202        )?))
2203    }
2204    fn load_xlora(
2205        &self,
2206        config: &str,
2207        vb: ShardedVarBuilder,
2208        lora_config: &[((String, String), LoraConfig)],
2209        xlora_config: Option<XLoraConfig>,
2210        xlora_ordering: Ordering,
2211        normal_loading_metadata: NormalLoadingMetadata,
2212        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2213    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2214        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2215
2216        Ok(Box::new(xlora_models::XLoraPhi3::new(
2217            &cfg,
2218            vb,
2219            lora_config,
2220            xlora_config,
2221            xlora_ordering,
2222            self.is_gptx(config)?,
2223            normal_loading_metadata,
2224            preload_adapters,
2225        )?))
2226    }
2227    fn is_gptx(&self, _: &str) -> Result<bool> {
2228        Ok(true)
2229    }
2230    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2231        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2232
2233        Ok(Box::new(cfg))
2234    }
2235}
2236
2237impl IsqModelLoader for Phi3_5MoELoader {
2238    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2239        Ok(vec![
2240            Regex::new(r"lm_head\.(weight|bias)$")?,
2241            // Attention
2242            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2243            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2244            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2245            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2246            // MLP
2247            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2248            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2249            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2250        ])
2251    }
2252    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2253        self.isq_layer_regexes(config)
2254    }
2255
2256    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2257        Ok(vec![
2258            Regex::new(r"lm_head\.(weight|bias)$")?,
2259            // MLP
2260            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2261            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2262            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2263        ])
2264    }
2265    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2266        self.isq_layer_regexes_moqe(config)
2267    }
2268}
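// The `_moqe` variants above quantize only the expert MLP weights (plus lm_head) and leave the
// attention projections in their original dtype; the expert weights dominate the parameter count,
// so this is typically where most of the memory savings come from.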
2269
2270impl DeviceMappedModelLoader for Phi3_5MoELoader {
2271    fn mapped_max_act_size_elems(
2272        &self,
2273        config: &str,
2274        params: &AutoDeviceMapParams,
2275    ) -> Result<usize> {
2276        let AutoDeviceMapParams::Text {
2277            max_seq_len,
2278            max_batch_size,
2279        } = params
2280        else {
2281            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2282        };
2283
2284        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2285
2286        Ok(
2287            max_batch_size
2288                * cfg.num_attention_heads
2289                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2290        )
2291    }
2292    fn non_mapped_max_act_size_elems(
2293        &self,
2294        _config: &str,
2295        _params: &AutoDeviceMapParams,
2296    ) -> Result<usize> {
2297        Ok(0)
2298    }
2299
2300    fn non_mapped_size_in_bytes(
2301        &self,
2302        config: &str,
2303        dtype: DType,
2304        weight_pack_factor: usize,
2305        _matformer_config: Option<&MatformerSliceConfig>,
2306    ) -> Result<usize> {
2307        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2308
2309        let elems = {
2310            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2311            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2312            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2313                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2314            } else {
2315                0
2316            };
2317            let norm = cfg.hidden_size;
2318            embed_tokens + lm_head + norm
2319        };
2320        Ok(elems * dtype.size_in_bytes())
2321    }
2322
2323    fn layer_sizes_in_bytes(
2324        &self,
2325        config: &str,
2326        dtype: DType,
2327        weight_pack_factor: usize,
2328        _matformer_config: Option<&MatformerSliceConfig>,
2329    ) -> Result<Vec<usize>> {
2330        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2331
2332        let per_layer_elems = {
2333            let input_layernorm = cfg.hidden_size;
2334            let post_attention_layernorm = cfg.hidden_size;
2335
2336            let size_in = cfg.hidden_size;
2337            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2338            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2339            let q_proj =
2340                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2341            let k_proj =
2342                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2343            let v_proj =
2344                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2345            let o_proj =
2346                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2347
2348            let moe_block = {
2349                let gate = cfg.hidden_size * cfg.num_local_experts;
2350                // Assume expert weights are packed with the same weight_pack_factor as the rest of the model
2351                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2352                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2353                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2354                gate + cfg.num_local_experts * w1
2355                    + cfg.num_local_experts * w2
2356                    + cfg.num_local_experts * w3
2357            };
2358
2359            input_layernorm
2360                + post_attention_layernorm
2361                + q_proj
2362                + k_proj
2363                + v_proj
2364                + o_proj
2365                + moe_block
2366        };
2367        Ok(vec![
2368            per_layer_elems * dtype.size_in_bytes();
2369            cfg.num_hidden_layers
2370        ])
2371    }
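    // Worked example for the MoE block (hypothetical Phi-3.5-MoE-like values, weight_pack_factor = 1:
    // hidden_size = 4096, intermediate_size = 6400, num_local_experts = 16):
    //   each expert w1/w2/w3 = 4096 * 6400 = 26_214_400 elements,
    //   experts in total     = 3 * 16 * 26_214_400 = 1_258_291_200 elements per layer
    // (about 2.3 GiB at 2 bytes per element), far larger than the attention projections.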
2372
2373    fn num_layers(&self, config: &str) -> Result<usize> {
2374        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2375
2376        Ok(cfg.num_hidden_layers)
2377    }
2378
2379    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2380        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2381
2382        let cfg = ModelConfigMetadata {
2383            max_seq_len: cfg.max_position_embeddings,
2384            num_layers: cfg.num_hidden_layers,
2385            hidden_size: cfg.hidden_size,
2386            num_kv_heads: cfg.num_key_value_heads,
2387            num_attn_heads: cfg.num_attention_heads,
2388            sliding_window: cfg.sliding_window,
2389            k_head_dim: cfg.head_dim(),
2390            v_head_dim: cfg.head_dim(),
2391        };
2392
2393        Ok(Box::new(cfg))
2394    }
2395}
2396
// ======================== DeepSeekV2 loader

2397/// [`NormalLoader`] for a DeepSeekV2 model.
2398///
2399/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2400pub struct DeepSeekV2Loader;
2401
2402impl NormalModelLoader for DeepSeekV2Loader {
2403    fn load(
2404        &self,
2405        config: &str,
2406        vb: ShardedVarBuilder,
2407        normal_loading_metadata: NormalLoadingMetadata,
2408        attention_mechanism: AttentionImplementation,
2409    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2410        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2411
2412        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2413            &cfg,
2414            vb,
2415            self.is_gptx(config)?,
2416            normal_loading_metadata,
2417            attention_mechanism,
2418        )?))
2419    }
2420    fn load_xlora(
2421        &self,
2422        _config: &str,
2423        _vb: ShardedVarBuilder,
2424        _lora_config: &[((String, String), LoraConfig)],
2425        _xlora_config: Option<XLoraConfig>,
2426        _xlora_ordering: Ordering,
2427        _normal_loading_metadata: NormalLoadingMetadata,
2428        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2429    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2430        todo!()
2431    }
2432    fn is_gptx(&self, _: &str) -> Result<bool> {
2433        Ok(true)
2434    }
2435    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2436        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2437        Ok(Box::new(cfg))
2438    }
2439}
2440
2441impl IsqModelLoader for DeepSeekV2Loader {
2442    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2443        let mut data = vec![
2444            Regex::new(r"lm_head\.(weight|bias)$")?,
2445            // Attention
2446            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2447            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2448            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2449        ];
2450        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2451        if cfg.q_lora_rank.is_some() {
2452            data.extend(vec![
2453                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2454                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2455            ]);
2456        } else {
2457            data.push(Regex::new(
2458                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2459            )?);
2460        }
2461        for layer_idx in 0..cfg.num_hidden_layers {
2462            if cfg.n_routed_experts.is_some()
2463                && layer_idx >= cfg.first_k_dense_replace
2464                && layer_idx % cfg.moe_layer_freq == 0
2465            {
2466                for i in 0..cfg.n_routed_experts.unwrap() {
2467                    data.extend(vec![
2468                        Regex::new(&format!(
2469                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2470                        ))?,
2471                        Regex::new(&format!(
2472                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2473                        ))?,
2474                        Regex::new(&format!(
2475                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2476                        ))?,
2477                    ]);
2478                }
2479                if cfg.n_shared_experts.is_some() {
2480                    data.extend(vec![
2481                        Regex::new(&format!(
2482                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2483                        ))?,
2484                        Regex::new(&format!(
2485                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2486                        ))?,
2487                        Regex::new(&format!(
2488                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2489                        ))?,
2490                    ]);
2491                }
2492            } else {
2493                data.extend(vec![
2494                    Regex::new(&format!(
2495                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2496                    ))?,
2497                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2498                    Regex::new(&format!(
2499                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2500                    ))?,
2501                ]);
2502            };
2503        }
2504        Ok(data)
2505    }
2506    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2507        self.isq_layer_regexes(config)
2508    }
2509
2510    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2511        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2512        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2513        for layer_idx in 0..cfg.num_hidden_layers {
2514            if cfg.n_routed_experts.is_some()
2515                && layer_idx >= cfg.first_k_dense_replace
2516                && layer_idx % cfg.moe_layer_freq == 0
2517            {
2518                for i in 0..cfg.n_routed_experts.unwrap() {
2519                    data.extend(vec![
2520                        Regex::new(&format!(
2521                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2522                        ))?,
2523                        Regex::new(&format!(
2524                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2525                        ))?,
2526                        Regex::new(&format!(
2527                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2528                        ))?,
2529                    ]);
2530                }
2531                if cfg.n_shared_experts.is_some() {
2532                    data.extend(vec![
2533                        Regex::new(&format!(
2534                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2535                        ))?,
2536                        Regex::new(&format!(
2537                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2538                        ))?,
2539                        Regex::new(&format!(
2540                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2541                        ))?,
2542                    ]);
2543                }
2544            } else {
2545                data.extend(vec![
2546                    Regex::new(&format!(
2547                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2548                    ))?,
2549                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2550                    Regex::new(&format!(
2551                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2552                    ))?,
2553                ]);
2554            };
2555        }
2556        Ok(data)
2557    }
2558    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2559        self.isq_layer_regexes_moqe(config)
2560    }
2561}
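// A layer is treated as MoE above when `n_routed_experts` is set, `layer_idx >= first_k_dense_replace`,
// and `layer_idx % moe_layer_freq == 0`. For example, with first_k_dense_replace = 1 and
// moe_layer_freq = 1 (values used here only for illustration), layer 0 keeps a dense MLP and every
// later layer uses the routed experts.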
2562
2563impl DeviceMappedModelLoader for DeepSeekV2Loader {
2564    fn mapped_max_act_size_elems(
2565        &self,
2566        config: &str,
2567        params: &AutoDeviceMapParams,
2568    ) -> Result<usize> {
2569        let AutoDeviceMapParams::Text {
2570            max_seq_len,
2571            max_batch_size,
2572        } = params
2573        else {
2574            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2575        };
2576
2577        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2578
2579        Ok(
2580            max_batch_size
2581                * cfg.num_attention_heads
2582                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2583        )
2584    }
2585    fn non_mapped_max_act_size_elems(
2586        &self,
2587        _config: &str,
2588        _params: &AutoDeviceMapParams,
2589    ) -> Result<usize> {
2590        Ok(0)
2591    }
2592
2593    fn non_mapped_size_in_bytes(
2594        &self,
2595        config: &str,
2596        dtype: DType,
2597        weight_pack_factor: usize,
2598        _matformer_config: Option<&MatformerSliceConfig>,
2599    ) -> Result<usize> {
2600        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2601        let elems = {
2602            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2603            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2604            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2605                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2606            } else {
2607                0
2608            };
2609            let norm = cfg.hidden_size;
2610            embed_tokens + lm_head + norm
2611        };
2612        Ok(elems * dtype.size_in_bytes())
2613    }
2614
2615    fn layer_sizes_in_bytes(
2616        &self,
2617        config: &str,
2618        dtype: DType,
2619        weight_pack_factor: usize,
2620        _matformer_config: Option<&MatformerSliceConfig>,
2621    ) -> Result<Vec<usize>> {
2622        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2623        let mut per_layer_elems = Vec::new();
2624
2625        for layer_idx in 0..cfg.num_hidden_layers {
2626            let input_layernorm = cfg.hidden_size;
2627            let post_attention_layernorm = cfg.hidden_size;
2628
2629            let q_proj = match cfg.q_lora_rank {
2630                Some(lora_rank) => {
2631                    let a = cfg.hidden_size * lora_rank;
2632                    let norm = lora_rank;
2633                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2634                    a + norm + b
2635                }
2636                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2637            };
2638            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2639                / weight_pack_factor
2640                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2641            let kv_a_layernorm = cfg.kv_lora_rank;
2642            let kv_b_proj = cfg.kv_lora_rank
2643                * cfg.num_attention_heads
2644                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2645                / weight_pack_factor;
2646            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2647                / weight_pack_factor
2648                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2649
2650            let moe_block = {
2651                let mut sum = 0;
2652                if cfg.n_routed_experts.is_some()
2653                    && layer_idx >= cfg.first_k_dense_replace
2654                    && layer_idx % cfg.moe_layer_freq == 0
2655                {
2656                    let h_size = cfg.hidden_size;
2657                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2658                        * cfg.n_routed_experts.unwrap();
2659                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2660                        * cfg.n_routed_experts.unwrap();
2661                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2662                        * cfg.n_routed_experts.unwrap();
2663                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2664                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2665                            / weight_pack_factor;
2666                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2667                            / weight_pack_factor;
2668                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2669                            / weight_pack_factor;
2670                        gate_proj + up_proj + down_proj
2671                    } else {
2672                        0
2673                    };
2674                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2675                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2676                } else {
2677                    let h_size = cfg.hidden_size;
2678                    let i_size = cfg.intermediate_size;
2679                    let gate_proj = h_size * i_size / weight_pack_factor;
2680                    let up_proj = h_size * i_size / weight_pack_factor;
2681                    let down_proj = i_size * h_size / weight_pack_factor;
2682                    sum += gate_proj + up_proj + down_proj;
2683                }
2684                sum
2685            };
2686
2687            per_layer_elems.push(
2688                input_layernorm
2689                    + post_attention_layernorm
2690                    + q_proj
2691                    + kv_a_layernorm
2692                    + kv_a_proj_with_mqa
2693                    + kv_b_proj
2694                    + o_proj
2695                    + moe_block,
2696            );
2697        }
2698
2699        Ok(per_layer_elems
2700            .into_iter()
2701            .map(|x| x * dtype.size_in_bytes())
2702            .collect())
2703    }
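    // Worked example for the MLA projections (illustrative DeepSeek-V2-like values,
    // weight_pack_factor = 1: hidden_size = 5120, 128 heads, q_lora_rank = 1536, kv_lora_rank = 512,
    // qk_rope_head_dim = 64, qk_nope_head_dim = 128, v_head_dim = 128, so q_head_dim() = 192):
    //   q_proj (low-rank)  = 5120*1536 + 1536 + (128*192)*1536 = 45_614_592
    //   kv_a_proj_with_mqa = 5120 * (512 + 64)                 =  2_949_120
    //   kv_b_proj          = 512 * 128 * (192 - 64 + 128)      = 16_777_216
    //   o_proj             = 128 * 128 * 5120                  = 83_886_080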
2704
2705    fn num_layers(&self, config: &str) -> Result<usize> {
2706        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2707        Ok(cfg.num_hidden_layers)
2708    }
2709
2710    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2711        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2712
2713        let cfg = ModelConfigMetadata {
2714            max_seq_len: cfg.max_position_embeddings,
2715            num_layers: cfg.num_hidden_layers,
2716            hidden_size: cfg.hidden_size,
2717            num_kv_heads: cfg.num_attention_heads,
2718            num_attn_heads: cfg.num_attention_heads,
2719            sliding_window: None,
2720            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2721            v_head_dim: cfg.v_head_dim,
2722        };
2723
2724        Ok(Box::new(cfg))
2725    }
2726}
2727
// ======================== DeepSeekV3 loader

2728/// [`NormalLoader`] for a DeepSeekV3 model.
2729///
2730/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2731pub struct DeepSeekV3Loader;
2732
2733impl NormalModelLoader for DeepSeekV3Loader {
2734    fn load(
2735        &self,
2736        config: &str,
2737        vb: ShardedVarBuilder,
2738        normal_loading_metadata: NormalLoadingMetadata,
2739        attention_mechanism: AttentionImplementation,
2740    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2741        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2742        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2743            &cfg,
2744            vb,
2745            self.is_gptx(config)?,
2746            normal_loading_metadata,
2747            attention_mechanism,
2748        )?))
2749    }
2750    fn load_xlora(
2751        &self,
2752        _config: &str,
2753        _vb: ShardedVarBuilder,
2754        _lora_config: &[((String, String), LoraConfig)],
2755        _xlora_config: Option<XLoraConfig>,
2756        _xlora_ordering: Ordering,
2757        _normal_loading_metadata: NormalLoadingMetadata,
2758        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2759    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2760        todo!()
2761    }
2762    fn is_gptx(&self, _: &str) -> Result<bool> {
2763        Ok(true)
2764    }
2765    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2766        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2767        Ok(Box::new(cfg))
2768    }
2769}
2770
2771impl IsqModelLoader for DeepSeekV3Loader {
2772    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2773        let mut data = vec![
2774            Regex::new(r"lm_head\.(weight|bias)$")?,
2775            // Attention
2776            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2777            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2778            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2779        ];
2780        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2781        if cfg.q_lora_rank.is_some() {
2782            data.extend(vec![
2783                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2784                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2785            ]);
2786        } else {
2787            data.push(Regex::new(
2788                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2789            )?);
2790        }
2791        for layer_idx in 0..cfg.num_hidden_layers {
2792            if cfg.n_routed_experts.is_some()
2793                && layer_idx >= cfg.first_k_dense_replace
2794                && layer_idx % cfg.moe_layer_freq == 0
2795            {
2796                for i in 0..cfg.n_routed_experts.unwrap() {
2797                    data.extend(vec![
2798                        Regex::new(&format!(
2799                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2800                        ))?,
2801                        Regex::new(&format!(
2802                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2803                        ))?,
2804                        Regex::new(&format!(
2805                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2806                        ))?,
2807                    ]);
2808                }
2809                if cfg.n_shared_experts.is_some() {
2810                    data.extend(vec![
2811                        Regex::new(&format!(
2812                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2813                        ))?,
2814                        Regex::new(&format!(
2815                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2816                        ))?,
2817                        Regex::new(&format!(
2818                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2819                        ))?,
2820                    ]);
2821                }
2822            } else {
2823                data.extend(vec![
2824                    Regex::new(&format!(
2825                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2826                    ))?,
2827                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2828                    Regex::new(&format!(
2829                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2830                    ))?,
2831                ]);
2832            };
2833        }
2834        Ok(data)
2835    }
2836    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2837        self.isq_layer_regexes(config)
2838    }
2839
2840    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2841        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2842        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2843        for layer_idx in 0..cfg.num_hidden_layers {
2844            if cfg.n_routed_experts.is_some()
2845                && layer_idx >= cfg.first_k_dense_replace
2846                && layer_idx % cfg.moe_layer_freq == 0
2847            {
2848                for i in 0..cfg.n_routed_experts.unwrap() {
2849                    data.extend(vec![
2850                        Regex::new(&format!(
2851                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2852                        ))?,
2853                        Regex::new(&format!(
2854                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2855                        ))?,
2856                        Regex::new(&format!(
2857                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2858                        ))?,
2859                    ]);
2860                }
2861                if cfg.n_shared_experts.is_some() {
2862                    data.extend(vec![
2863                        Regex::new(&format!(
2864                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2865                        ))?,
2866                        Regex::new(&format!(
2867                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2868                        ))?,
2869                        Regex::new(&format!(
2870                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2871                        ))?,
2872                    ]);
2873                }
2874            } else {
2875                data.extend(vec![
2876                    Regex::new(&format!(
2877                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2878                    ))?,
2879                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2880                    Regex::new(&format!(
2881                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2882                    ))?,
2883                ]);
2884            };
2885        }
2886        Ok(data)
2887    }
2888    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2889        self.isq_layer_regexes_moqe(config)
2890    }
2891}
2892
2893impl DeviceMappedModelLoader for DeepSeekV3Loader {
2894    fn mapped_max_act_size_elems(
2895        &self,
2896        config: &str,
2897        params: &AutoDeviceMapParams,
2898    ) -> Result<usize> {
2899        let AutoDeviceMapParams::Text {
2900            max_seq_len,
2901            max_batch_size,
2902        } = params
2903        else {
2904            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2905        };
2906
2907        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2908
2909        Ok(
2910            max_batch_size
2911                * cfg.num_attention_heads
2912                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2913        )
2914    }
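    // Rough intuition for the estimate above: it approximates the peak attention-score
    // activation of shape [batch, heads, chunk, chunk]. With illustrative (assumed)
    // numbers, say max_batch_size = 1, num_attention_heads = 32 and a 1024-element
    // ATTENTION_CHUNK_SIZE, that is 1 * 32 * 1024 * 1024 = 33_554_432 elements.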
2915    fn non_mapped_max_act_size_elems(
2916        &self,
2917        _config: &str,
2918        _params: &AutoDeviceMapParams,
2919    ) -> Result<usize> {
2920        Ok(0)
2921    }
2922
2923    fn non_mapped_size_in_bytes(
2924        &self,
2925        config: &str,
2926        dtype: DType,
2927        weight_pack_factor: usize,
2928        _matformer_config: Option<&MatformerSliceConfig>,
2929    ) -> Result<usize> {
2930        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2931        let elems = {
2932            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2933            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2934            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2935                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2936            } else {
2937                0
2938            };
2939            let norm = cfg.hidden_size;
2940            embed_tokens + lm_head + norm
2941        };
2942        Ok(elems * dtype.size_in_bytes())
2943    }
2944
2945    fn layer_sizes_in_bytes(
2946        &self,
2947        config: &str,
2948        dtype: DType,
2949        weight_pack_factor: usize,
2950        _matformer_config: Option<&MatformerSliceConfig>,
2951    ) -> Result<Vec<usize>> {
2952        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2953        let mut per_layer_elems = Vec::new();
2954
2955        for layer_idx in 0..cfg.num_hidden_layers {
2956            let input_layernorm = cfg.hidden_size;
2957            let post_attention_layernorm = cfg.hidden_size;
2958
2959            let q_proj = match cfg.q_lora_rank {
2960                Some(lora_rank) => {
2961                    let a = cfg.hidden_size * lora_rank;
2962                    let norm = lora_rank;
2963                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2964                    a + norm + b
2965                }
2966                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2967            };
2968            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2969                / weight_pack_factor
2970                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
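            // `bias_if!` contributes the bias element count only when the flag is set, and 0 otherwise.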
2971            let kv_a_layernorm = cfg.kv_lora_rank;
2972            let kv_b_proj = cfg.kv_lora_rank
2973                * cfg.num_attention_heads
2974                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2975                / weight_pack_factor;
2976            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2977                / weight_pack_factor
2978                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2979
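            // A layer uses the MoE block once its index reaches `first_k_dense_replace` and it
            // falls on a `moe_layer_freq` boundary; all other layers keep a dense MLP.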
2980            let moe_block = {
2981                let mut sum = 0;
2982                if cfg.n_routed_experts.is_some()
2983                    && layer_idx >= cfg.first_k_dense_replace
2984                    && layer_idx % cfg.moe_layer_freq == 0
2985                {
2986                    let h_size = cfg.hidden_size;
2987                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2988                        * cfg.n_routed_experts.unwrap();
2989                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2990                        * cfg.n_routed_experts.unwrap();
2991                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2992                        * cfg.n_routed_experts.unwrap();
2993                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2994                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2995                            / weight_pack_factor;
2996                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2997                            / weight_pack_factor;
2998                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2999                            / weight_pack_factor;
3000                        gate_proj + up_proj + down_proj
3001                    } else {
3002                        0
3003                    };
3004                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
3005                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3006                } else {
3007                    let h_size = cfg.hidden_size;
3008                    let i_size = cfg.intermediate_size;
3009                    let gate_proj = h_size * i_size / weight_pack_factor;
3010                    let up_proj = h_size * i_size / weight_pack_factor;
3011                    let down_proj = i_size * h_size / weight_pack_factor;
3012                    sum += gate_proj + up_proj + down_proj;
3013                }
3014                sum
3015            };
3016
3017            per_layer_elems.push(
3018                input_layernorm
3019                    + post_attention_layernorm
3020                    + q_proj
3021                    + kv_a_layernorm
3022                    + kv_a_proj_with_mqa
3023                    + kv_b_proj
3024                    + o_proj
3025                    + moe_block,
3026            );
3027        }
3028
3029        Ok(per_layer_elems
3030            .into_iter()
3031            .map(|x| x * dtype.size_in_bytes())
3032            .collect())
3033    }
3034
3035    fn num_layers(&self, config: &str) -> Result<usize> {
3036        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3037        Ok(cfg.num_hidden_layers)
3038    }
3039
3040    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3041        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3042
3043        let cfg = ModelConfigMetadata {
3044            max_seq_len: cfg.max_position_embeddings,
3045            num_layers: cfg.num_hidden_layers,
3046            hidden_size: cfg.hidden_size,
3047            num_kv_heads: cfg.num_attention_heads,
3048            num_attn_heads: cfg.num_attention_heads,
3049            sliding_window: None,
3050            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3051            v_head_dim: cfg.v_head_dim,
3052        };
3053
3054        Ok(Box::new(cfg))
3055    }
3056}
3057
3058/// [`NormalLoader`] for a Qwen 3 model.
3059///
3060/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3061pub struct Qwen3Loader;
3062
3063impl NormalModelLoader for Qwen3Loader {
3064    fn load(
3065        &self,
3066        config: &str,
3067        vb: ShardedVarBuilder,
3068        normal_loading_metadata: NormalLoadingMetadata,
3069        attention_mechanism: AttentionImplementation,
3070    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3071        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3072
3073        Ok(Box::new(models::qwen3::Model::new(
3074            &cfg,
3075            vb,
3076            self.is_gptx(config)?,
3077            normal_loading_metadata,
3078            attention_mechanism,
3079        )?))
3080    }
3081    fn load_xlora(
3082        &self,
3083        _config: &str,
3084        _vb: ShardedVarBuilder,
3085        _lora_config: &[((String, String), LoraConfig)],
3086        _xlora_config: Option<XLoraConfig>,
3087        _xlora_ordering: Ordering,
3088        _normal_loading_metadata: NormalLoadingMetadata,
3089        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3090    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3091        todo!()
3092    }
3093    fn is_gptx(&self, _: &str) -> Result<bool> {
3094        Ok(true)
3095    }
3096    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3097        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3098
3099        Ok(Box::new(cfg))
3100    }
3101}
3102
3103impl IsqModelLoader for Qwen3Loader {
3104    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3105        Ok(vec![
3106            Regex::new(r"lm_head\.(weight|bias)$")?,
3107            // Attention
3108            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3109            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3110            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3111            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3112            // MLP
3113            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3114            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3115            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3116        ])
3117    }
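    // Minimal sketch of what the predicates above select (the tensor name below is an
    // assumption for illustration, not taken from a real model):
    //
    //     let re = Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$").unwrap();
    //     assert!(re.is_match("model.layers.0.self_attn.q_proj.weight"));
    //
    // Embedding and norm weights are intentionally not matched, so they stay unquantized.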
3118    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3119        self.isq_layer_regexes(config)
3120    }
3121}
3122
3123impl DeviceMappedModelLoader for Qwen3Loader {
3124    fn mapped_max_act_size_elems(
3125        &self,
3126        config: &str,
3127        params: &AutoDeviceMapParams,
3128    ) -> Result<usize> {
3129        let AutoDeviceMapParams::Text {
3130            max_seq_len,
3131            max_batch_size,
3132        } = params
3133        else {
3134            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3135        };
3136
3137        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3138
3139        Ok(
3140            max_batch_size
3141                * cfg.num_attention_heads
3142                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3143        )
3144    }
3145    fn non_mapped_max_act_size_elems(
3146        &self,
3147        _config: &str,
3148        _params: &AutoDeviceMapParams,
3149    ) -> Result<usize> {
3150        Ok(0)
3151    }
3152
3153    fn non_mapped_size_in_bytes(
3154        &self,
3155        config: &str,
3156        dtype: DType,
3157        weight_pack_factor: usize,
3158        _matformer_config: Option<&MatformerSliceConfig>,
3159    ) -> Result<usize> {
3160        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3161        let elems = {
3162            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3163            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3164            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3165                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3166            } else {
3167                0
3168            };
3169            let norm = cfg.hidden_size;
3170            embed_tokens + lm_head + norm
3171        };
3172        Ok(elems * dtype.size_in_bytes())
3173    }
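    // Worked illustration of the estimate above (all numbers are assumptions, not a real
    // config): with hidden_size = 4096, vocab_size = 151_936, untied embeddings and
    // weight_pack_factor = 1, elems = 4096 * 151_936 (embed_tokens) + 4096 * 151_936
    // (lm_head) + 4096 (norm), and the result is elems * dtype.size_in_bytes()
    // (2 bytes per element for BF16).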
3174
3175    fn layer_sizes_in_bytes(
3176        &self,
3177        config: &str,
3178        dtype: DType,
3179        weight_pack_factor: usize,
3180        _matformer_config: Option<&MatformerSliceConfig>,
3181    ) -> Result<Vec<usize>> {
3182        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3183        let per_layer_elems = {
3184            let input_layernorm = cfg.hidden_size;
3185            let post_attention_layernorm = cfg.hidden_size;
3186
3187            let size_in = cfg.hidden_size;
3188            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3189            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3190            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3191            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3192            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3193            let o_proj = size_q * size_in / weight_pack_factor;
3194
3195            let h_size = cfg.hidden_size;
3196            let i_size = cfg.intermediate_size;
3197            let gate_proj = h_size * i_size / weight_pack_factor;
3198            let up_proj = h_size * i_size / weight_pack_factor;
3199            let down_proj = i_size * h_size / weight_pack_factor;
3200
3201            let q_norm = cfg.head_dim();
3202            let k_norm = cfg.head_dim();
3203
3204            input_layernorm
3205                + post_attention_layernorm
3206                + q_proj
3207                + k_proj
3208                + v_proj
3209                + o_proj
3210                + gate_proj
3211                + up_proj
3212                + down_proj
3213                + q_norm
3214                + k_norm
3215        };
3216        Ok(vec![
3217            per_layer_elems * dtype.size_in_bytes();
3218            cfg.num_hidden_layers
3219        ])
3220    }
3221
3222    fn num_layers(&self, config: &str) -> Result<usize> {
3223        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3224        Ok(cfg.num_hidden_layers)
3225    }
3226
3227    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3228        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3229
3230        let cfg = ModelConfigMetadata {
3231            max_seq_len: cfg.max_position_embeddings,
3232            num_layers: cfg.num_hidden_layers,
3233            hidden_size: cfg.hidden_size,
3234            num_kv_heads: cfg.num_key_value_heads,
3235            num_attn_heads: cfg.num_attention_heads,
3236            sliding_window: cfg.sliding_window,
3237            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3238            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3239        };
3240
3241        Ok(Box::new(cfg))
3242    }
3243}
3244
3245/// [`NormalLoader`] for a GLM 4 model.
3246///
3247/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3248pub struct GLM4Loader;
3249
3250impl NormalModelLoader for GLM4Loader {
3251    fn load(
3252        &self,
3253        config: &str,
3254        vb: ShardedVarBuilder,
3255        normal_loading_metadata: NormalLoadingMetadata,
3256        attention_mechanism: AttentionImplementation,
3257    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3258        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3259
3260        Ok(Box::new(models::glm4::Model::new(
3261            &cfg,
3262            vb,
3263            self.is_gptx(config)?,
3264            normal_loading_metadata,
3265            attention_mechanism,
3266        )?))
3267    }
3268    fn load_xlora(
3269        &self,
3270        _config: &str,
3271        _vb: ShardedVarBuilder,
3272        _lora_config: &[((String, String), LoraConfig)],
3273        _xlora_config: Option<XLoraConfig>,
3274        _xlora_ordering: Ordering,
3275        _normal_loading_metadata: NormalLoadingMetadata,
3276        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3277    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3278        todo!()
3279    }
3280    fn is_gptx(&self, _: &str) -> Result<bool> {
3281        Ok(true)
3282    }
3283    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3284        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3285
3286        Ok(Box::new(cfg))
3287    }
3288}
3289
3290impl IsqModelLoader for GLM4Loader {
3291    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3292        Ok(vec![
3293            Regex::new(r"lm_head\.(weight|bias)$")?,
3294            // Attention
3295            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3296            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3297            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3298            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3299            // MLP
3300            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3301            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3302            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3303        ])
3304    }
3305    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3306        self.isq_layer_regexes(config)
3307    }
3308}
3309
3310impl DeviceMappedModelLoader for GLM4Loader {
3311    fn mapped_max_act_size_elems(
3312        &self,
3313        config: &str,
3314        params: &AutoDeviceMapParams,
3315    ) -> Result<usize> {
3316        let AutoDeviceMapParams::Text {
3317            max_seq_len,
3318            max_batch_size,
3319        } = params
3320        else {
3321            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3322        };
3323
3324        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3325
3326        Ok(
3327            max_batch_size
3328                * cfg.num_attention_heads
3329                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3330        )
3331    }
3332    fn non_mapped_max_act_size_elems(
3333        &self,
3334        _config: &str,
3335        _params: &AutoDeviceMapParams,
3336    ) -> Result<usize> {
3337        Ok(0)
3338    }
3339
3340    fn non_mapped_size_in_bytes(
3341        &self,
3342        config: &str,
3343        dtype: DType,
3344        weight_pack_factor: usize,
3345        _matformer_config: Option<&MatformerSliceConfig>,
3346    ) -> Result<usize> {
3347        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3348        let elems = {
3349            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3350            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3351            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3352                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3353            } else {
3354                0
3355            };
3356            let norm = cfg.hidden_size;
3357            embed_tokens + lm_head + norm
3358        };
3359        Ok(elems * dtype.size_in_bytes())
3360    }
3361
3362    fn layer_sizes_in_bytes(
3363        &self,
3364        config: &str,
3365        dtype: DType,
3366        weight_pack_factor: usize,
3367        _matformer_config: Option<&MatformerSliceConfig>,
3368    ) -> Result<Vec<usize>> {
3369        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3370        let per_layer_elems = {
3371            let input_layernorm = cfg.hidden_size;
3372            let post_attention_layernorm = cfg.hidden_size * 3; // also covers post_self_attn_layernorm and post_mlp_layernorm
3373
3374            let size_in = cfg.hidden_size;
3375            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3376            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3377            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3378            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3379            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3380            let o_proj = size_q * size_in / weight_pack_factor;
3381
3382            let h_size = cfg.hidden_size;
3383            let i_size = cfg.intermediate_size;
3384            let gate_proj = h_size * i_size / weight_pack_factor;
3385            let up_proj = h_size * i_size / weight_pack_factor;
3386            let down_proj = i_size * h_size / weight_pack_factor;
3387
3388            input_layernorm
3389                + post_attention_layernorm
3390                + q_proj
3391                + k_proj
3392                + v_proj
3393                + o_proj
3394                + gate_proj
3395                + up_proj
3396                + down_proj
3397        };
3398        Ok(vec![
3399            per_layer_elems * dtype.size_in_bytes();
3400            cfg.num_hidden_layers
3401        ])
3402    }
3403
3404    fn num_layers(&self, config: &str) -> Result<usize> {
3405        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3406        Ok(cfg.num_hidden_layers)
3407    }
3408
3409    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3410        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3411
3412        let cfg = ModelConfigMetadata {
3413            max_seq_len: cfg.max_position_embeddings,
3414            num_layers: cfg.num_hidden_layers,
3415            hidden_size: cfg.hidden_size,
3416            num_kv_heads: cfg.num_key_value_heads,
3417            num_attn_heads: cfg.num_attention_heads,
3418            sliding_window: cfg.sliding_window,
3419            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3420            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3421        };
3422
3423        Ok(Box::new(cfg))
3424    }
3425}
3426
3427/// [`NormalLoader`] for a Qwen 3 MoE model.
3428///
3429/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3430pub struct Qwen3MoELoader;
3431
3432impl NormalModelLoader for Qwen3MoELoader {
3433    fn load(
3434        &self,
3435        config: &str,
3436        vb: ShardedVarBuilder,
3437        normal_loading_metadata: NormalLoadingMetadata,
3438        attention_mechanism: AttentionImplementation,
3439    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3440        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3441
3442        Ok(Box::new(models::qwen3_moe::Model::new(
3443            &cfg,
3444            vb,
3445            self.is_gptx(config)?,
3446            normal_loading_metadata,
3447            attention_mechanism,
3448        )?))
3449    }
3450    fn load_xlora(
3451        &self,
3452        _config: &str,
3453        _vb: ShardedVarBuilder,
3454        _lora_config: &[((String, String), LoraConfig)],
3455        _xlora_config: Option<XLoraConfig>,
3456        _xlora_ordering: Ordering,
3457        _normal_loading_metadata: NormalLoadingMetadata,
3458        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3459    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3460        todo!()
3461    }
3462    fn is_gptx(&self, _: &str) -> Result<bool> {
3463        Ok(true)
3464    }
3465    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3466        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3467
3468        Ok(Box::new(cfg))
3469    }
3470}
3471
3472impl IsqModelLoader for Qwen3MoELoader {
3473    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3474        Ok(vec![
3475            Regex::new(r"lm_head\.(weight|bias)$")?,
3476            // Attention
3477            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3478            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3479            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3480            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3481            // MLP
3482            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3483            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3484            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3485            // MLP MoE
3486            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3487            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3488            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3489        ])
3490    }
3491    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3492        self.isq_layer_regexes(config)
3493    }
3494    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3495        self.isq_layer_regexes_moqe(config)
3496    }
3497}
3498
3499impl DeviceMappedModelLoader for Qwen3MoELoader {
3500    fn mapped_max_act_size_elems(
3501        &self,
3502        config: &str,
3503        params: &AutoDeviceMapParams,
3504    ) -> Result<usize> {
3505        let AutoDeviceMapParams::Text {
3506            max_seq_len,
3507            max_batch_size,
3508        } = params
3509        else {
3510            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3511        };
3512
3513        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3514
3515        Ok(
3516            max_batch_size
3517                * cfg.num_attention_heads
3518                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3519        )
3520    }
3521    fn non_mapped_max_act_size_elems(
3522        &self,
3523        _config: &str,
3524        _params: &AutoDeviceMapParams,
3525    ) -> Result<usize> {
3526        Ok(0)
3527    }
3528
3529    fn non_mapped_size_in_bytes(
3530        &self,
3531        config: &str,
3532        dtype: DType,
3533        weight_pack_factor: usize,
3534        _matformer_config: Option<&MatformerSliceConfig>,
3535    ) -> Result<usize> {
3536        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3537        let elems = {
3538            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3539            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3540            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3541                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3542            } else {
3543                0
3544            };
3545            let norm = cfg.hidden_size;
3546            embed_tokens + lm_head + norm
3547        };
3548        Ok(elems * dtype.size_in_bytes())
3549    }
3550
3551    fn layer_sizes_in_bytes(
3552        &self,
3553        config: &str,
3554        dtype: DType,
3555        weight_pack_factor: usize,
3556        _matformer_config: Option<&MatformerSliceConfig>,
3557    ) -> Result<Vec<usize>> {
3558        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3559
3560        let mut layer_sizes_in_bytes = Vec::new();
3561        for layer_idx in 0..cfg.num_hidden_layers {
3562            let input_layernorm = cfg.hidden_size;
3563            let post_attention_layernorm = cfg.hidden_size;
3564
3565            let size_in = cfg.hidden_size;
3566            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3567            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3568            let q_proj = size_in * size_q / weight_pack_factor;
3569            let k_proj = size_in * size_kv / weight_pack_factor;
3570            let v_proj = size_in * size_kv / weight_pack_factor;
3571            let o_proj = size_q * size_in / weight_pack_factor;
3572
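            // A layer gets routed experts unless it is listed in `mlp_only_layers`; expert
            // layers recur every `decoder_sparse_step` layers (and only if `num_experts > 0`).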
3573            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3574                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3575            {
3576                let gate_size = cfg.hidden_size * cfg.num_experts;
3577                let expert_size = {
3578                    let h_size = cfg.hidden_size;
3579                    let i_size = cfg.moe_intermediate_size;
3580                    let gate_proj = h_size * i_size / weight_pack_factor;
3581                    let up_proj = h_size * i_size / weight_pack_factor;
3582                    let down_proj = i_size * h_size / weight_pack_factor;
3583                    gate_proj + up_proj + down_proj
3584                };
3585                expert_size * cfg.num_experts + gate_size
3586            } else {
3587                let h_size = cfg.hidden_size;
3588                let i_size = cfg.intermediate_size;
3589                let gate_proj = h_size * i_size / weight_pack_factor;
3590                let up_proj = h_size * i_size / weight_pack_factor;
3591                let down_proj = i_size * h_size / weight_pack_factor;
3592                gate_proj + up_proj + down_proj
3593            };
3594
3595            let q_norm = cfg.head_dim();
3596            let k_norm = cfg.head_dim();
3597
3598            let size_elems = input_layernorm
3599                + post_attention_layernorm
3600                + q_proj
3601                + k_proj
3602                + v_proj
3603                + o_proj
3604                + mlp_size
3605                + q_norm
3606                + k_norm;
3607
3608            let size_in_bytes = size_elems * dtype.size_in_bytes();
3609            layer_sizes_in_bytes.push(size_in_bytes);
3610        }
3611
3612        Ok(layer_sizes_in_bytes)
3613    }
3614
3615    fn num_layers(&self, config: &str) -> Result<usize> {
3616        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3617        Ok(cfg.num_hidden_layers)
3618    }
3619
3620    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3621        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3622
3623        let cfg = ModelConfigMetadata {
3624            max_seq_len: cfg.max_position_embeddings,
3625            num_layers: cfg.num_hidden_layers,
3626            hidden_size: cfg.hidden_size,
3627            num_kv_heads: cfg.num_key_value_heads,
3628            num_attn_heads: cfg.num_attention_heads,
3629            sliding_window: cfg.sliding_window,
3630            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3631            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3632        };
3633
3634        Ok(Box::new(cfg))
3635    }
3636}
3637
3638// ======================== SmolLm3 loader
3639
3640/// [`NormalLoader`] for a SmolLm3 model.
3641///
3642/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3643pub struct SmolLm3Loader;
3644
3645impl NormalModelLoader for SmolLm3Loader {
3646    fn load(
3647        &self,
3648        config: &str,
3649        vb: ShardedVarBuilder,
3650        normal_loading_metadata: NormalLoadingMetadata,
3651        attention_mechanism: AttentionImplementation,
3652    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3653        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3654
3655        Ok(Box::new(models::smollm3::SmolLm3::new(
3656            &cfg,
3657            vb,
3658            self.is_gptx(config)?,
3659            normal_loading_metadata,
3660            attention_mechanism,
3661        )?))
3662    }
3663    fn load_xlora(
3664        &self,
3665        _config: &str,
3666        _vb: ShardedVarBuilder,
3667        _lora_config: &[((String, String), LoraConfig)],
3668        _xlora_config: Option<XLoraConfig>,
3669        _xlora_ordering: Ordering,
3670        _normal_loading_metadata: NormalLoadingMetadata,
3671        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3672    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3673        todo!()
3674    }
3675    fn is_gptx(&self, _: &str) -> Result<bool> {
3676        Ok(true)
3677    }
3678    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3679        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3680        Ok(Box::new(cfg))
3681    }
3682}
3683
3684impl IsqModelLoader for SmolLm3Loader {
3685    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3686        Ok(vec![
3687            Regex::new(r"lm_head\.(weight|bias)$")?,
3688            // Attention
3689            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3690            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3691            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3692            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3693            // MLP
3694            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3695            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3696            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3697        ])
3698    }
3699    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3700        self.isq_layer_regexes(config)
3701    }
3702}
3703
3704impl DeviceMappedModelLoader for SmolLm3Loader {
3705    fn mapped_max_act_size_elems(
3706        &self,
3707        config: &str,
3708        params: &AutoDeviceMapParams,
3709    ) -> Result<usize> {
3710        let AutoDeviceMapParams::Text {
3711            max_seq_len,
3712            max_batch_size,
3713        } = params
3714        else {
3715            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3716        };
3717
3718        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3719
3720        Ok(
3721            max_batch_size
3722                * cfg.num_attention_heads
3723                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3724        )
3725    }
3726    fn non_mapped_max_act_size_elems(
3727        &self,
3728        _config: &str,
3729        _params: &AutoDeviceMapParams,
3730    ) -> Result<usize> {
3731        Ok(0)
3732    }
3733
3734    fn non_mapped_size_in_bytes(
3735        &self,
3736        config: &str,
3737        dtype: DType,
3738        weight_pack_factor: usize,
3739        _matformer_config: Option<&MatformerSliceConfig>,
3740    ) -> Result<usize> {
3741        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3742
3743        let elems = {
3744            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3745            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3746            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3747                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3748            } else {
3749                0
3750            };
3751            let norm = cfg.hidden_size;
3752            embed_tokens + lm_head + norm
3753        };
3754        Ok(elems * dtype.size_in_bytes())
3755    }
3756
3757    fn layer_sizes_in_bytes(
3758        &self,
3759        config: &str,
3760        dtype: DType,
3761        weight_pack_factor: usize,
3762        _matformer_config: Option<&MatformerSliceConfig>,
3763    ) -> Result<Vec<usize>> {
3764        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3765
3766        let per_layer_elems = {
3767            let input_layernorm = cfg.hidden_size;
3768            let post_attention_layernorm = cfg.hidden_size;
3769
3770            let size_in = cfg.hidden_size;
3771            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3772            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
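            // head_dim is derived as hidden_size / num_attention_heads here, so (assuming the
            // usual even division) size_q works out to hidden_size again.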
3773            let q_proj = size_in * size_q / weight_pack_factor;
3774            let k_proj = size_in * size_kv / weight_pack_factor;
3775            let v_proj = size_in * size_kv / weight_pack_factor;
3776            let o_proj = size_q * size_in / weight_pack_factor;
3777
3778            let h_size = cfg.hidden_size;
3779            let i_size = cfg.intermediate_size;
3780            let gate_proj = h_size * i_size / weight_pack_factor;
3781            let up_proj = h_size * i_size / weight_pack_factor;
3782            let down_proj = i_size * h_size / weight_pack_factor;
3783
3784            input_layernorm
3785                + post_attention_layernorm
3786                + q_proj
3787                + k_proj
3788                + v_proj
3789                + o_proj
3790                + gate_proj
3791                + up_proj
3792                + down_proj
3793        };
3794        Ok(vec![
3795            per_layer_elems * dtype.size_in_bytes();
3796            cfg.num_hidden_layers
3797        ])
3798    }
3799
3800    fn num_layers(&self, config: &str) -> Result<usize> {
3801        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3802
3803        Ok(cfg.num_hidden_layers)
3804    }
3805    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3806        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3807
3808        let cfg = ModelConfigMetadata {
3809            max_seq_len: cfg.max_position_embeddings,
3810            num_layers: cfg.num_hidden_layers,
3811            hidden_size: cfg.hidden_size,
3812            num_kv_heads: cfg.num_key_value_heads,
3813            num_attn_heads: cfg.num_attention_heads,
3814            sliding_window: None,
3815            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3816            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3817        };
3818
3819        Ok(Box::new(cfg))
3820    }
3821}
3822
3823// ======================== GraniteMoeHybrid loader
3824
3825/// [`NormalLoader`] for a GraniteMoeHybrid model (IBM Granite 4.0).
3826///
3827/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3828pub struct GraniteMoeHybridLoader;
3829
3830impl NormalModelLoader for GraniteMoeHybridLoader {
3831    fn load(
3832        &self,
3833        config: &str,
3834        vb: ShardedVarBuilder,
3835        normal_loading_metadata: NormalLoadingMetadata,
3836        attention_mechanism: AttentionImplementation,
3837    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3838        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3839
3840        Ok(Box::new(models::granite::GraniteMoeHybrid::new(
3841            &cfg,
3842            vb,
3843            self.is_gptx(config)?,
3844            normal_loading_metadata,
3845            attention_mechanism,
3846        )?))
3847    }
3848    fn load_xlora(
3849        &self,
3850        _config: &str,
3851        _vb: ShardedVarBuilder,
3852        _lora_config: &[((String, String), LoraConfig)],
3853        _xlora_config: Option<XLoraConfig>,
3854        _xlora_ordering: Ordering,
3855        _normal_loading_metadata: NormalLoadingMetadata,
3856        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3857    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3858        todo!()
3859    }
3860    fn is_gptx(&self, _: &str) -> Result<bool> {
3861        Ok(true)
3862    }
3863    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3864        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3865        Ok(Box::new(cfg))
3866    }
3867}
3868
3869impl IsqModelLoader for GraniteMoeHybridLoader {
3870    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3871        Ok(vec![
3872            Regex::new(r"lm_head\.(weight|bias)$")?,
3873            // Attention
3874            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3875            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3876            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3877            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3878            // MLP (GraniteMLP uses shared_mlp.input_linear and shared_mlp.output_linear)
3879            Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
3880            Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
3881        ])
3882    }
3883    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3884        self.isq_layer_regexes(config)
3885    }
3886}
3887
3888impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
3889    fn mapped_max_act_size_elems(
3890        &self,
3891        config: &str,
3892        params: &AutoDeviceMapParams,
3893    ) -> Result<usize> {
3894        let AutoDeviceMapParams::Text {
3895            max_seq_len,
3896            max_batch_size,
3897        } = params
3898        else {
3899            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3900        };
3901
3902        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3903
3904        Ok(
3905            max_batch_size
3906                * cfg.num_attention_heads
3907                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3908        )
3909    }
3910    fn non_mapped_max_act_size_elems(
3911        &self,
3912        _config: &str,
3913        _params: &AutoDeviceMapParams,
3914    ) -> Result<usize> {
3915        Ok(0)
3916    }
3917
3918    fn non_mapped_size_in_bytes(
3919        &self,
3920        config: &str,
3921        dtype: DType,
3922        weight_pack_factor: usize,
3923        _matformer_config: Option<&MatformerSliceConfig>,
3924    ) -> Result<usize> {
3925        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3926
3927        let elems = {
3928            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3929            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3930            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3931                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3932            } else {
3933                0
3934            };
3935            let norm = cfg.hidden_size;
3936            embed_tokens + lm_head + norm
3937        };
3938        Ok(elems * dtype.size_in_bytes())
3939    }
3940
3941    fn layer_sizes_in_bytes(
3942        &self,
3943        config: &str,
3944        dtype: DType,
3945        weight_pack_factor: usize,
3946        _matformer_config: Option<&MatformerSliceConfig>,
3947    ) -> Result<Vec<usize>> {
3948        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3949
3950        let per_layer_elems = {
3951            let input_layernorm = cfg.hidden_size;
3952            let post_attention_layernorm = cfg.hidden_size;
3953
3954            let size_in = cfg.hidden_size;
3955            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3956            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
3957            let q_proj = size_in * size_q / weight_pack_factor;
3958            let k_proj = size_in * size_kv / weight_pack_factor;
3959            let v_proj = size_in * size_kv / weight_pack_factor;
3960            let o_proj = size_q * size_in / weight_pack_factor;
3961
3962            let h_size = cfg.hidden_size;
3963            let shared_i_size = cfg.shared_intermediate_size();
3964            // GraniteMLP: input_linear (h_size -> shared_i_size * 2), output_linear (shared_i_size -> h_size)
3965            let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
3966            let output_linear = shared_i_size * h_size / weight_pack_factor;
3967
3968            input_layernorm
3969                + post_attention_layernorm
3970                + q_proj
3971                + k_proj
3972                + v_proj
3973                + o_proj
3974                + input_linear
3975                + output_linear
3976        };
3977        Ok(vec![
3978            per_layer_elems * dtype.size_in_bytes();
3979            cfg.num_hidden_layers
3980        ])
3981    }
3982
3983    fn num_layers(&self, config: &str) -> Result<usize> {
3984        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3985
3986        Ok(cfg.num_hidden_layers)
3987    }
3988    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3989        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
3990
3991        let cfg = ModelConfigMetadata {
3992            max_seq_len: cfg.max_position_embeddings,
3993            num_layers: cfg.num_hidden_layers,
3994            hidden_size: cfg.hidden_size,
3995            num_kv_heads: cfg.num_key_value_heads(),
3996            num_attn_heads: cfg.num_attention_heads,
3997            sliding_window: None,
3998            k_head_dim: cfg.head_dim(),
3999            v_head_dim: cfg.head_dim(),
4000        };
4001
4002        Ok(Box::new(cfg))
4003    }
4004}
4005
4006// ======================== GPT-OSS loader
4007
4008/// [`NormalLoader`] for a GPT-OSS model.
4009///
4010/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
4011pub struct GptOssLoader;
4012
4013impl NormalModelLoader for GptOssLoader {
4014    fn load(
4015        &self,
4016        config: &str,
4017        vb: ShardedVarBuilder,
4018        normal_loading_metadata: NormalLoadingMetadata,
4019        attention_mechanism: AttentionImplementation,
4020    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4021        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4022
4023        Ok(Box::new(models::gpt_oss::Model::new(
4024            &cfg,
4025            vb,
4026            self.is_gptx(config)?,
4027            normal_loading_metadata,
4028            attention_mechanism,
4029        )?))
4030    }
4031    fn load_xlora(
4032        &self,
4033        _config: &str,
4034        _vb: ShardedVarBuilder,
4035        _lora_config: &[((String, String), LoraConfig)],
4036        _xlora_config: Option<XLoraConfig>,
4037        _xlora_ordering: Ordering,
4038        _normal_loading_metadata: NormalLoadingMetadata,
4039        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4040    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4041        anyhow::bail!("GPT-OSS does not support X-LoRA")
4042    }
4043    fn is_gptx(&self, _: &str) -> Result<bool> {
4044        Ok(true)
4045    }
4046    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4047        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4048        Ok(Box::new(cfg))
4049    }
4050    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4051        Ok(false)
4052    }
4053}
4054
4055impl IsqModelLoader for GptOssLoader {
4056    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4057        // Only attention layers are ISQ-able; the MoE experts are already MXFP4-quantized.
4058        Ok(vec![
4059            Regex::new(r"lm_head\.(weight|bias)$")?,
4060            // Attention
4061            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4062            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4063            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4064            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4065        ])
4066    }
4067    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4068        self.isq_layer_regexes(config)
4069    }
4070}
4071
4072impl DeviceMappedModelLoader for GptOssLoader {
4073    fn mapped_max_act_size_elems(
4074        &self,
4075        config: &str,
4076        params: &AutoDeviceMapParams,
4077    ) -> Result<usize> {
4078        let AutoDeviceMapParams::Text {
4079            max_seq_len,
4080            max_batch_size,
4081        } = params
4082        else {
4083            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4084        };
4085
4086        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4087
4088        Ok(
4089            max_batch_size
4090                * cfg.num_attention_heads
4091                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4092        )
4093    }
4094    fn non_mapped_max_act_size_elems(
4095        &self,
4096        _config: &str,
4097        _params: &AutoDeviceMapParams,
4098    ) -> Result<usize> {
4099        Ok(0)
4100    }
4101
4102    fn non_mapped_size_in_bytes(
4103        &self,
4104        config: &str,
4105        dtype: DType,
4106        weight_pack_factor: usize,
4107        _matformer_config: Option<&MatformerSliceConfig>,
4108    ) -> Result<usize> {
4109        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4110
4111        let elems = {
4112            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4113            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4114                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4115            } else {
4116                0
4117            };
4118            let norm = cfg.hidden_size;
4119            embed_tokens + lm_head + norm
4120        };
4121        Ok(elems * dtype.size_in_bytes())
4122    }
4123
4124    fn layer_sizes_in_bytes(
4125        &self,
4126        config: &str,
4127        dtype: DType,
4128        weight_pack_factor: usize,
4129        _matformer_config: Option<&MatformerSliceConfig>,
4130    ) -> Result<Vec<usize>> {
4131        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4132
4133        let per_layer_elems = {
4134            let input_layernorm = cfg.hidden_size;
4135            let post_attention_layernorm = cfg.hidden_size;
4136
4137            let size_in = cfg.hidden_size;
4138            let head_dim = cfg.head_dim();
4139            let size_q = head_dim * cfg.num_attention_heads;
4140            let size_kv = head_dim * cfg.num_key_value_heads;
4141            let q_proj =
4142                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
4143            let k_proj =
4144                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4145            let v_proj =
4146                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4147            let o_proj =
4148                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
4149
4150            // MoE experts - MXFP4 quantized, so very compact
4151            // gate_up_proj: [num_experts, intermediate_size * 2, hidden_size/2] packed
4152            // down_proj: [num_experts, hidden_size, intermediate_size/2] packed
4153            // At 4 bits per weight, packing factor is 2
4154            let mxfp4_pack = 2;
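            // Worked illustration (hypothetical count, for intuition only): 1_000_000 MXFP4
            // weights pack into 1_000_000 / 2 = 500_000 bytes, plus 1_000_000 / 32 = 31_250
            // bytes of per-block u8 scales.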
4155            let gate_up_proj_size =
4156                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / mxfp4_pack;
4157            let down_proj_size =
4158                cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / mxfp4_pack;
4159            // Plus scales at 1 byte per 32 elements
4160            let gate_up_scales =
4161                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / 32;
4162            let down_scales = cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / 32;
4163            // Plus biases
4164            let gate_up_bias = cfg.num_local_experts * cfg.intermediate_size * 2;
4165            let down_bias = cfg.num_local_experts * cfg.hidden_size;
4166            // Router
4167            let router = cfg.hidden_size * cfg.num_local_experts;
4168            // Sinks per head
4169            let sinks = cfg.num_attention_heads;
4170
4171            input_layernorm
4172                + post_attention_layernorm
4173                + q_proj
4174                + k_proj
4175                + v_proj
4176                + o_proj
4177                + gate_up_proj_size
4178                + down_proj_size
4179                + gate_up_scales
4180                + down_scales
4181                + gate_up_bias
4182                + down_bias
4183                + router
4184                + sinks
4185        };
4186        Ok(vec![
4187            per_layer_elems * dtype.size_in_bytes();
4188            cfg.num_hidden_layers
4189        ])
4190    }
4191
4192    fn num_layers(&self, config: &str) -> Result<usize> {
4193        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4194
4195        Ok(cfg.num_hidden_layers)
4196    }
4197    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4198        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4199
4200        let head_dim = cfg.head_dim();
4201        let cfg = ModelConfigMetadata {
4202            max_seq_len: cfg.max_position_embeddings,
4203            num_layers: cfg.num_hidden_layers,
4204            hidden_size: cfg.hidden_size,
4205            num_kv_heads: cfg.num_key_value_heads,
4206            num_attn_heads: cfg.num_attention_heads,
4207            sliding_window: cfg.sliding_window,
4208            k_head_dim: head_dim,
4209            v_head_dim: head_dim,
4210        };
4211
4212        Ok(Box::new(cfg))
4213    }
4214}