mistralrs_core/pipeline/loaders/normal_loaders.rs

use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    lora::{LoraConfig, Ordering},
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{
        isq::IsqModelLoader,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::varbuilder_utils::DeviceForLoadTensor,
    xlora_models::NonGranularState,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use mistralrs_quant::log::once_log_info;

use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use crate::{
    models,
    xlora_models::{self, XLoraConfig},
};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn is_xlora(&self) -> bool;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
}
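
// A minimal usage sketch (not in the original file): once a `NormalModel` has been built
// by one of the loaders below, the text pipeline drives it roughly like this. The variable
// names (`model`, `input_ids`, `flash_params`, ...) are hypothetical placeholders for the
// values the pipeline's input processor actually produces.
//
//     let logits = model.forward(
//         &input_ids,      // token ids, shape (batch, seq_len)
//         &seqlen_offsets, // per-sequence KV-cache offsets
//         context_lens,
//         position_ids,
//         None,            // paged-attention metadata, if enabled
//         &flash_params,
//     )?;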

/// Metadata for loading a model with ISQ or device mapping.
pub struct NormalLoadingMetadata {
    /// Device mapping metadata which can be used to construct a concrete device mapper.
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Flag to check if loading in ISQ.
    pub loading_isq: bool,
    /// Device mapping target device (the one that is not the CPU).
    pub real_device: Device,
    /// MultiProgress support for parallelized loading.
    pub multi_progress: Arc<MultiProgress>,
    /// Optional Matryoshka Transformer slicing configuration.
    pub matformer_slicing_config: Option<MatformerSliceConfig>,
}

pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    #[allow(clippy::too_many_arguments)]
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}
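
// Sketch (added for illustration, not in the original file): the default
// `get_device_for_tensor` maps a tensor name such as
// `model.layers.12.self_attn.q_proj.weight` to `DeviceForLoadTensor::Idx(12)` via the
// regex above; anything without a `.layers.N.` component falls back to
// `DeviceForLoadTensor::Base`.
#[cfg(test)]
mod device_for_tensor_regex_sketch {
    #[test]
    fn extracts_layer_index_from_tensor_name() {
        let re = regex::Regex::new(r"\.layers\.(\d+)\.").unwrap();
        let caps = re.captures("model.layers.12.self_attn.q_proj.weight").unwrap();
        let idx: usize = caps.get(1).unwrap().as_str().parse().unwrap();
        assert_eq!(idx, 12);
        assert!(re.captures("lm_head.weight").is_none());
    }
}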

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the normal model as.
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    #[serde(rename = "qwen3")]
    Qwen3,
    #[serde(rename = "glm4")]
    GLM4,
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
    #[serde(rename = "smollm3")]
    SmolLm3,
}
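
// Sketch (added for illustration): the `#[serde(rename = "...")]` attributes are what let
// an architecture string in user-facing JSON select a variant. `serde_json` is already used
// elsewhere in this file; the JSON snippet itself is hypothetical.
#[cfg(test)]
mod loader_type_serde_sketch {
    use super::NormalLoaderType;

    #[test]
    fn deserializes_renamed_variant() {
        let tp: NormalLoaderType = serde_json::from_str("\"phi3.5moe\"").unwrap();
        assert_eq!(tp, NormalLoaderType::Phi3_5MoE);
    }
}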

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
impl NormalLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "MistralForCausalLM" => Ok(Self::Mistral),
            "MixtralForCausalLM" => Ok(Self::Mixtral),
            "GemmaForCausalLM" => Ok(Self::Gemma),
            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
            "PhiForCausalLM" => Ok(Self::Phi2),
            "Phi3ForCausalLM" => Ok(Self::Phi3),
            "LlamaForCausalLM" => Ok(Self::Llama),
            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
            "Glm4ForCausalLM" => Ok(Self::GLM4),
            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `*ForCausalLM` model class `{other}`. Please raise an issue."
            ),
        }
    }
}
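
// Sketch (added for illustration): mapping the `architectures` entry of a Hugging Face
// `config.json` onto a loader type.
#[cfg(test)]
mod causal_lm_name_sketch {
    use super::NormalLoaderType;

    #[test]
    fn maps_transformers_class_names() {
        let tp = NormalLoaderType::from_causal_lm_name("LlamaForCausalLM").unwrap();
        assert_eq!(tp, NormalLoaderType::Llama);
        assert!(NormalLoaderType::from_causal_lm_name("NotARealForCausalLM").is_err());
    }
}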

impl FromStr for NormalLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "mistral" => Ok(Self::Mistral),
            "gemma" => Ok(Self::Gemma),
            "mixtral" => Ok(Self::Mixtral),
            "llama" => Ok(Self::Llama),
            "phi2" => Ok(Self::Phi2),
            "phi3" => Ok(Self::Phi3),
            "qwen2" => Ok(Self::Qwen2),
            "gemma2" => Ok(Self::Gemma2),
            "starcoder2" => Ok(Self::Starcoder2),
            "phi3.5moe" => Ok(Self::Phi3_5MoE),
            "deepseekv2" => Ok(Self::DeepSeekV2),
            "deepseekv3" => Ok(Self::DeepSeekV3),
            "qwen3" => Ok(Self::Qwen3),
            "glm4" => Ok(Self::GLM4),
            "qwen3moe" => Ok(Self::Qwen3Moe),
            "smollm3" => Ok(Self::SmolLm3),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`.")),
        }
    }
}

impl Display for NormalLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gemma => write!(f, "gemma"),
            Self::Gemma2 => write!(f, "gemma2"),
            Self::Llama => write!(f, "llama"),
            Self::Mistral => write!(f, "mistral"),
            Self::Mixtral => write!(f, "mixtral"),
            Self::Phi2 => write!(f, "phi2"),
            Self::Phi3 => write!(f, "phi3"),
            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
            Self::Qwen2 => write!(f, "qwen2"),
            Self::Starcoder2 => write!(f, "starcoder2"),
            Self::DeepSeekV2 => write!(f, "deepseekv2"),
            Self::DeepSeekV3 => write!(f, "deepseekv3"),
            Self::Qwen3 => write!(f, "qwen3"),
            Self::GLM4 => write!(f, "glm4"),
            Self::Qwen3Moe => write!(f, "qwen3moe"),
            Self::SmolLm3 => write!(f, "smollm3"),
        }
    }
}
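
// Sketch (added for illustration): `FromStr` and `Display` are kept in sync, so a
// user-supplied architecture string round-trips through both.
#[cfg(test)]
mod loader_type_roundtrip_sketch {
    use super::NormalLoaderType;
    use std::str::FromStr;

    #[test]
    fn from_str_and_display_round_trip() {
        let tp = NormalLoaderType::from_str("qwen3moe").unwrap();
        assert_eq!(tp, NormalLoaderType::Qwen3Moe);
        assert_eq!(tp.to_string(), "qwen3moe");
    }
}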

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
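
// Sketch (added for illustration): `bias_if!` counts bias elements only when the config
// enables them, e.g. `bias_if!(cfg.attention_bias, size_q)` in the size estimates below.
#[cfg(test)]
mod bias_if_sketch {
    #[test]
    fn counts_bias_only_when_enabled() {
        assert_eq!(bias_if!(true, 4096), 4096);
        assert_eq!(bias_if!(false, 4096), 0);
    }
}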

/// Load a model based on the Hugging Face Transformers `*ForCausalLM` model class.
pub struct AutoNormalLoader;

#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    architectures: Vec<String>,
}

impl AutoNormalLoader {
    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one name in the `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = NormalLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
        }
    }
}
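
// Sketch (added for illustration): `AutoNormalLoader` only needs the `architectures` field
// of a model's `config.json` to pick the concrete loader; the rest of the config is parsed
// later by that loader. The minimal config snippet below is hypothetical.
#[cfg(test)]
mod auto_loader_selection_sketch {
    use super::AutoNormalLoader;

    #[test]
    fn selects_loader_from_architectures_field() {
        let config = r#"{"architectures": ["MistralForCausalLM"]}"#;
        assert!(AutoNormalLoader::get_loader(config).is_ok());
    }
}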

impl NormalModelLoader for AutoNormalLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load_xlora(
            config,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            normal_loading_metadata,
            preload_adapters,
        )
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }
    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.supports_paged_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoNormalLoader {
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
    }
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for AutoNormalLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

// ======================== Mistral loader

pub struct MistralLoader;

impl NormalModelLoader for MistralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(models::mistral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(xlora_models::XLoraMistral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MistralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
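
// Sketch (added for illustration): the ISQ regexes above are matched against full tensor
// names, so typical projection weights are selected while e.g. the embedding table is not.
#[cfg(test)]
mod mistral_isq_regex_sketch {
    use super::{IsqModelLoader, MistralLoader};

    #[test]
    fn matches_projection_weights() {
        let regexes = MistralLoader.isq_layer_regexes("{}").unwrap();
        assert!(regexes
            .iter()
            .any(|re| re.is_match("model.layers.0.self_attn.q_proj.weight")));
        assert!(!regexes
            .iter()
            .any(|re| re.is_match("model.embed_tokens.weight")));
    }
}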

impl DeviceMappedModelLoader for MistralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}
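
// Worked example (added for illustration; assumes Mistral-7B-like shapes rather than a
// specific checkpoint): with hidden_size = 4096, 32 attention heads, 8 KV heads, and
// intermediate_size = 14336, `layer_sizes_in_bytes` above counts per layer
//   2 * 4096                                  (the two norms)
// + 4096*4096 + 2*(4096*1024) + 4096*4096     (q/k/v/o projections)
// + 3 * (4096*14336)                          (gate/up/down projections)
// = 218,112,000 elements, i.e. roughly 416 MiB per layer at 2 bytes/element (BF16) before
// any ISQ weight packing (`weight_pack_factor` > 1) shrinks the projection terms.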

// ======================== Gemma loader

/// [`NormalLoader`] for a Gemma model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct GemmaLoader;

impl NormalModelLoader for GemmaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::gemma::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraGemma::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for GemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Llama loader

/// [`NormalLoader`] for a Llama model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct LlamaLoader;

impl NormalModelLoader for LlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::llama::Llama::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraLlama::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for LlamaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for LlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Mixtral loader

pub struct MixtralLoader;

impl NormalModelLoader for MixtralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::mixtral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraMixtral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MixtralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // Experts
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MixtralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let moe_block = {
                let gate = cfg.hidden_size * cfg.num_local_experts;
                // Assume the expert weights are packed with the same weight pack factor
                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                gate + cfg.num_local_experts * w1
                    + cfg.num_local_experts * w2
                    + cfg.num_local_experts * w3
            };

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + moe_block
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}
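
// Worked example (added for illustration; assumes Mixtral-8x7B-like shapes): with
// hidden_size = 4096, intermediate_size = 14336, and num_local_experts = 8, the
// `moe_block` term in `layer_sizes_in_bytes` above contributes
// 4096*8 + 8 * 3 * (4096*14336) = 1,409,318,912 elements per layer, i.e. about 2.6 GiB
// at 2 bytes/element (BF16) before weight packing, which is why expert weights dominate
// the device-mapping budget for MoE models.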

// ======================== Phi2 loader

/// [`NormalLoader`] for a Phi 2 model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Phi2Loader;

impl NormalModelLoader for Phi2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::phi2::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraPhi2::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Phi2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor + size_in;
            let (q_norm, k_norm) = if cfg.qk_layernorm {
                (cfg.head_dim(), cfg.head_dim())
            } else {
                (0, 0)
            };

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let fc1 = h_size * i_size / weight_pack_factor;
            let fc2 = h_size * i_size / weight_pack_factor;

            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}
1391
1392// ======================== Phi3 loader
1393
1394/// [`NormalLoader`] for a Phi 3 model.
1395///
1396/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1397pub struct Phi3Loader;
1398
1399impl NormalModelLoader for Phi3Loader {
1400    fn load(
1401        &self,
1402        config: &str,
1403        vb: ShardedVarBuilder,
1404        normal_loading_metadata: NormalLoadingMetadata,
1405        attention_mechanism: AttentionImplementation,
1406    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1407        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1408
1409        Ok(Box::new(models::phi3::Model::new(
1410            &cfg,
1411            vb,
1412            self.is_gptx(config)?,
1413            normal_loading_metadata,
1414            attention_mechanism,
1415        )?))
1416    }
1417    fn load_xlora(
1418        &self,
1419        config: &str,
1420        vb: ShardedVarBuilder,
1421        lora_config: &[((String, String), LoraConfig)],
1422        xlora_config: Option<XLoraConfig>,
1423        xlora_ordering: Ordering,
1424        normal_loading_metadata: NormalLoadingMetadata,
1425        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1426    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1427        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1428
1429        Ok(Box::new(xlora_models::XLoraPhi3::new(
1430            &cfg,
1431            vb,
1432            lora_config,
1433            xlora_config,
1434            xlora_ordering,
1435            self.is_gptx(config)?,
1436            normal_loading_metadata,
1437            preload_adapters,
1438        )?))
1439    }
1440    fn is_gptx(&self, _: &str) -> Result<bool> {
1441        Ok(true)
1442    }
1443    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1444        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1445
1446        Ok(Box::new(cfg))
1447    }
1448}
1449
1450impl IsqModelLoader for Phi3Loader {
1451    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1452        Ok(vec![
1453            Regex::new(r"lm_head\.(weight|bias)$")?,
1454            // Attention
1455            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1456            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1457            // MLP
1458            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1459            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1460            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1461        ])
1462    }
1463    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1464        self.isq_layer_regexes(config)
1465    }
1466}
1467
1468impl DeviceMappedModelLoader for Phi3Loader {
1469    fn mapped_max_act_size_elems(
1470        &self,
1471        config: &str,
1472        params: &AutoDeviceMapParams,
1473    ) -> Result<usize> {
1474        let AutoDeviceMapParams::Text {
1475            max_seq_len,
1476            max_batch_size,
1477        } = params
1478        else {
1479            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1480        };
1481
1482        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1483
1484        Ok(
1485            max_batch_size
1486                * cfg.num_attention_heads
1487                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1488        )
1489    }
1490    fn non_mapped_max_act_size_elems(
1491        &self,
1492        _config: &str,
1493        _params: &AutoDeviceMapParams,
1494    ) -> Result<usize> {
1495        Ok(0)
1496    }
1497
1498    fn non_mapped_size_in_bytes(
1499        &self,
1500        config: &str,
1501        dtype: DType,
1502        weight_pack_factor: usize,
1503        _matformer_config: Option<&MatformerSliceConfig>,
1504    ) -> Result<usize> {
1505        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1506
1507        let elems = {
1508            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1509            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1510            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1511                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1512            } else {
1513                0
1514            };
1515            let norm = cfg.hidden_size;
1516            embed_tokens + lm_head + norm
1517        };
1518        Ok(elems * dtype.size_in_bytes())
1519    }
1520
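    // Illustrative arithmetic (hypothetical Phi-3-mini-like shapes, not read from a real
    // config): with hidden_size = 3072, head_dim = 96, 32 attention heads and 32 KV heads,
    // op_size = 32*96 + 2*32*96 = 9216, so the fused qkv_proj alone contributes
    // 3072 * 9216 = 28_311_552 elements, i.e. about 54 MiB per layer in F16 before packing.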
1521    fn layer_sizes_in_bytes(
1522        &self,
1523        config: &str,
1524        dtype: DType,
1525        weight_pack_factor: usize,
1526        _matformer_config: Option<&MatformerSliceConfig>,
1527    ) -> Result<Vec<usize>> {
1528        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1529
1530        let per_layer_elems = {
1531            let input_layernorm = cfg.hidden_size;
1532            let post_attention_layernorm = cfg.hidden_size;
1533
1534            let size_in = cfg.hidden_size;
1535            let head_dim = cfg.head_dim();
1536            let op_size =
1537                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1538            let qkv_proj = size_in * op_size / weight_pack_factor;
1539            let o_proj =
1540                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1541
1542            let h_size = cfg.hidden_size;
1543            let i_size = cfg.intermediate_size;
1544            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1545            let down_proj = h_size * i_size / weight_pack_factor;
1546
1547            input_layernorm
1548                + post_attention_layernorm
1549                + qkv_proj
1550                + o_proj
1551                + gate_up_proj
1552                + down_proj
1553        };
1554        Ok(vec![
1555            per_layer_elems * dtype.size_in_bytes();
1556            cfg.num_hidden_layers
1557        ])
1558    }
1559
1560    fn num_layers(&self, config: &str) -> Result<usize> {
1561        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1562
1563        Ok(cfg.num_hidden_layers)
1564    }
1565
1566    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1567        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1568
1569        let cfg = ModelConfigMetadata {
1570            max_seq_len: cfg.max_position_embeddings,
1571            num_layers: cfg.num_hidden_layers,
1572            hidden_size: cfg.hidden_size,
1573            num_kv_heads: cfg.num_key_value_heads,
1574            num_attn_heads: cfg.num_attention_heads,
1575            sliding_window: cfg.sliding_window,
1576            k_head_dim: cfg.head_dim(),
1577            v_head_dim: cfg.head_dim(),
1578        };
1579
1580        Ok(Box::new(cfg))
1581    }
1582}
1583
1584// ======================== Qwen2 loader
1585
1586/// [`NormalLoader`] for a Qwen 2 model.
1587///
1588/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
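///
/// As an illustration of what the `IsqModelLoader` implementation below selects
/// (tensor names are made up for the example, not copied from a checkpoint):
///
/// ```text
/// model.layers.0.self_attn.q_proj.weight  -> matches, eligible for ISQ
/// model.layers.0.mlp.down_proj.weight     -> matches, eligible for ISQ
/// model.layers.0.input_layernorm.weight   -> no match, left as-is
/// lm_head.weight                          -> matches, eligible for ISQ
/// ```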
1589pub struct Qwen2Loader;
1590
1591impl NormalModelLoader for Qwen2Loader {
1592    fn load(
1593        &self,
1594        config: &str,
1595        vb: ShardedVarBuilder,
1596        normal_loading_metadata: NormalLoadingMetadata,
1597        attention_mechanism: AttentionImplementation,
1598    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1599        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1600
1601        Ok(Box::new(models::qwen2::Model::new(
1602            &cfg,
1603            vb,
1604            self.is_gptx(config)?,
1605            normal_loading_metadata,
1606            attention_mechanism,
1607        )?))
1608    }
1609    fn load_xlora(
1610        &self,
1611        _config: &str,
1612        _vb: ShardedVarBuilder,
1613        _lora_config: &[((String, String), LoraConfig)],
1614        _xlora_config: Option<XLoraConfig>,
1615        _xlora_ordering: Ordering,
1616        _normal_loading_metadata: NormalLoadingMetadata,
1617        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1618    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1619        todo!()
1620    }
1621    fn is_gptx(&self, _: &str) -> Result<bool> {
1622        Ok(true)
1623    }
1624    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1625        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1626
1627        Ok(Box::new(cfg))
1628    }
1629}
1630
1631impl IsqModelLoader for Qwen2Loader {
1632    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1633        Ok(vec![
1634            Regex::new(r"lm_head\.(weight|bias)$")?,
1635            // Attention
1636            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1637            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1638            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1639            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1640            // MLP
1641            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1642            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1643            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1644        ])
1645    }
1646    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1647        self.isq_layer_regexes(config)
1648    }
1649}
1650
1651impl DeviceMappedModelLoader for Qwen2Loader {
1652    fn mapped_max_act_size_elems(
1653        &self,
1654        config: &str,
1655        params: &AutoDeviceMapParams,
1656    ) -> Result<usize> {
1657        let AutoDeviceMapParams::Text {
1658            max_seq_len,
1659            max_batch_size,
1660        } = params
1661        else {
1662            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1663        };
1664
1665        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1666
1667        Ok(
1668            max_batch_size
1669                * cfg.num_attention_heads
1670                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1671        )
1672    }
1673    fn non_mapped_max_act_size_elems(
1674        &self,
1675        _config: &str,
1676        _params: &AutoDeviceMapParams,
1677    ) -> Result<usize> {
1678        Ok(0)
1679    }
1680
1681    fn non_mapped_size_in_bytes(
1682        &self,
1683        config: &str,
1684        dtype: DType,
1685        weight_pack_factor: usize,
1686        _matformer_config: Option<&MatformerSliceConfig>,
1687    ) -> Result<usize> {
1688        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1689
1690        let elems = {
1691            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1692            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1693            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1694                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1695            } else {
1696                0
1697            };
1698            let norm = cfg.hidden_size;
1699            embed_tokens + lm_head + norm
1700        };
1701        Ok(elems * dtype.size_in_bytes())
1702    }
1703
1704    fn layer_sizes_in_bytes(
1705        &self,
1706        config: &str,
1707        dtype: DType,
1708        weight_pack_factor: usize,
1709        _matformer_config: Option<&MatformerSliceConfig>,
1710    ) -> Result<Vec<usize>> {
1711        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1712
1713        let per_layer_elems = {
1714            let input_layernorm = cfg.hidden_size;
1715            let post_attention_layernorm = cfg.hidden_size;
1716
1717            let size_in = cfg.hidden_size;
1718            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1719            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1720            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1721            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1722            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1723            let o_proj = size_q * size_in / weight_pack_factor;
1724
1725            let h_size = cfg.hidden_size;
1726            let i_size = cfg.intermediate_size;
1727            let gate_proj = h_size * i_size / weight_pack_factor;
1728            let up_proj = h_size * i_size / weight_pack_factor;
1729            let down_proj = i_size * h_size / weight_pack_factor;
1730
1731            input_layernorm
1732                + post_attention_layernorm
1733                + q_proj
1734                + k_proj
1735                + v_proj
1736                + o_proj
1737                + gate_proj
1738                + up_proj
1739                + down_proj
1740        };
1741        Ok(vec![
1742            per_layer_elems * dtype.size_in_bytes();
1743            cfg.num_hidden_layers
1744        ])
1745    }
1746
1747    fn num_layers(&self, config: &str) -> Result<usize> {
1748        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1749
1750        Ok(cfg.num_hidden_layers)
1751    }
1752
1753    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1754        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1755
1756        let cfg = ModelConfigMetadata {
1757            max_seq_len: cfg.max_position_embeddings,
1758            num_layers: cfg.num_hidden_layers,
1759            hidden_size: cfg.hidden_size,
1760            num_kv_heads: cfg.num_key_value_heads,
1761            num_attn_heads: cfg.num_attention_heads,
1762            sliding_window: cfg.sliding_window,
1763            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1764            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1765        };
1766
1767        Ok(Box::new(cfg))
1768    }
1769}
1770
1771// ======================== Gemma2 loader
1772
1773/// [`NormalLoader`] for a Gemma2 model.
1774///
1775/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
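///
/// A minimal sketch of driving the device-mapping heuristics below (hedged:
/// `config_json` is a hypothetical Gemma 2 `config.json` string; the imports
/// mirror this module's own):
///
/// ```ignore
/// use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
///
/// let loader = Gemma2Loader;
/// let params = AutoDeviceMapParams::Text { max_seq_len: 4096, max_batch_size: 1 };
/// // Upper bound on attention activation elements for the device-mapped layers.
/// let act_elems = loader.mapped_max_act_size_elems(config_json, &params)?;
/// ```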
1776pub struct Gemma2Loader;
1777
1778impl NormalModelLoader for Gemma2Loader {
1779    fn load(
1780        &self,
1781        config: &str,
1782        vb: ShardedVarBuilder,
1783        normal_loading_metadata: NormalLoadingMetadata,
1784        attention_mechanism: AttentionImplementation,
1785    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1786        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1787
1788        Ok(Box::new(models::gemma2::Model::new(
1789            &cfg,
1790            vb,
1791            self.is_gptx(config)?,
1792            normal_loading_metadata,
1793            attention_mechanism,
1794        )?))
1795    }
1796    fn load_xlora(
1797        &self,
1798        config: &str,
1799        vb: ShardedVarBuilder,
1800        lora_config: &[((String, String), LoraConfig)],
1801        xlora_config: Option<XLoraConfig>,
1802        xlora_ordering: Ordering,
1803        normal_loading_metadata: NormalLoadingMetadata,
1804        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1805    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1806        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1807
1808        Ok(Box::new(xlora_models::XLoraGemma2::new(
1809            &cfg,
1810            vb,
1811            lora_config,
1812            xlora_config,
1813            xlora_ordering,
1814            self.is_gptx(config)?,
1815            normal_loading_metadata,
1816            preload_adapters,
1817        )?))
1818    }
1819    fn is_gptx(&self, _: &str) -> Result<bool> {
1820        Ok(true)
1821    }
1822    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1823        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1824
1825        Ok(Box::new(cfg))
1826    }
1827}
1828
1829impl IsqModelLoader for Gemma2Loader {
1830    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1831        Ok(vec![
1832            Regex::new(r"lm_head\.(weight|bias)$")?,
1833            // Attention
1834            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1835            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1836            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1837            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1838            // MLP
1839            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1840            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1841            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1842        ])
1843    }
1844    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1845        self.isq_layer_regexes(config)
1846    }
1847}
1848
1849impl DeviceMappedModelLoader for Gemma2Loader {
1850    fn mapped_max_act_size_elems(
1851        &self,
1852        config: &str,
1853        params: &AutoDeviceMapParams,
1854    ) -> Result<usize> {
1855        let AutoDeviceMapParams::Text {
1856            max_seq_len,
1857            max_batch_size,
1858        } = params
1859        else {
1860            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1861        };
1862
1863        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1864
1865        Ok(
1866            max_batch_size
1867                * cfg.num_attention_heads
1868                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1869        )
1870    }
1871    fn non_mapped_max_act_size_elems(
1872        &self,
1873        _config: &str,
1874        _params: &AutoDeviceMapParams,
1875    ) -> Result<usize> {
1876        Ok(0)
1877    }
1878
1879    fn non_mapped_size_in_bytes(
1880        &self,
1881        config: &str,
1882        dtype: DType,
1883        weight_pack_factor: usize,
1884        _matformer_config: Option<&MatformerSliceConfig>,
1885    ) -> Result<usize> {
1886        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1887
1888        let elems = {
1889            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1890            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1891            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1892                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1893            } else {
1894                0
1895            };
1896            let norm = cfg.hidden_size;
1897            embed_tokens + lm_head + norm
1898        };
1899        Ok(elems * dtype.size_in_bytes())
1900    }
1901
1902    fn layer_sizes_in_bytes(
1903        &self,
1904        config: &str,
1905        dtype: DType,
1906        weight_pack_factor: usize,
1907        _matformer_config: Option<&MatformerSliceConfig>,
1908    ) -> Result<Vec<usize>> {
1909        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1910
1911        let per_layer_elems = {
1912            let input_layernorm = cfg.hidden_size;
1913            let post_attention_layernorm = cfg.hidden_size;
1914
1915            let size_in = cfg.hidden_size;
1916            let size_q = cfg.head_dim * cfg.num_attention_heads;
1917            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1918            let q_proj =
1919                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1920            let k_proj =
1921                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1922            let v_proj =
1923                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1924            let o_proj =
1925                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1926
1927            let h_size = cfg.hidden_size;
1928            let i_size = cfg.intermediate_size;
1929            let gate_proj = h_size * i_size / weight_pack_factor;
1930            let up_proj = h_size * i_size / weight_pack_factor;
1931            let down_proj = i_size * h_size / weight_pack_factor;
1932
1933            input_layernorm
1934                + post_attention_layernorm
1935                + q_proj
1936                + k_proj
1937                + v_proj
1938                + o_proj
1939                + gate_proj
1940                + up_proj
1941                + down_proj
1942        };
1943        Ok(vec![
1944            per_layer_elems * dtype.size_in_bytes();
1945            cfg.num_hidden_layers
1946        ])
1947    }
1948
1949    fn num_layers(&self, config: &str) -> Result<usize> {
1950        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1951
1952        Ok(cfg.num_hidden_layers)
1953    }
1954    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1955        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1956
1957        let cfg = ModelConfigMetadata {
1958            max_seq_len: cfg.max_position_embeddings,
1959            num_layers: cfg.num_hidden_layers,
1960            hidden_size: cfg.hidden_size,
1961            num_kv_heads: cfg.num_key_value_heads,
1962            num_attn_heads: cfg.num_attention_heads,
1963            sliding_window: None, // None to be more forgiving; some configs do not set it
1964            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1965            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1966        };
1967
1968        Ok(Box::new(cfg))
1969    }
1970}
1971
1972// ======================== Starcoder2 loader
1973
1974/// [`NormalLoader`] for a Starcoder2 model.
1975///
1976/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
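///
/// Note that the ISQ regexes below target the Starcoder2 MLP tensor names
/// `mlp.fc1` and `mlp.c_proj` rather than the `gate_proj`/`up_proj`/`down_proj`
/// names used by the Llama-style loaders above, e.g. (illustrative names only):
///
/// ```text
/// model.layers.0.mlp.fc1.weight     -> matches, eligible for ISQ
/// model.layers.0.mlp.c_proj.weight  -> matches, eligible for ISQ
/// ```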
1977pub struct Starcoder2Loader;
1978
1979impl NormalModelLoader for Starcoder2Loader {
1980    fn load(
1981        &self,
1982        config: &str,
1983        vb: ShardedVarBuilder,
1984        normal_loading_metadata: NormalLoadingMetadata,
1985        attention_mechanism: AttentionImplementation,
1986    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1987        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1988
1989        Ok(Box::new(models::starcoder2::Model::new(
1990            &cfg,
1991            vb,
1992            self.is_gptx(config)?,
1993            normal_loading_metadata,
1994            attention_mechanism,
1995        )?))
1996    }
1997    fn load_xlora(
1998        &self,
1999        config: &str,
2000        vb: ShardedVarBuilder,
2001        lora_config: &[((String, String), LoraConfig)],
2002        xlora_config: Option<XLoraConfig>,
2003        xlora_ordering: Ordering,
2004        normal_loading_metadata: NormalLoadingMetadata,
2005        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2006    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2007        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2008
2009        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2010            &cfg,
2011            vb,
2012            lora_config,
2013            xlora_config,
2014            xlora_ordering,
2015            self.is_gptx(config)?,
2016            normal_loading_metadata,
2017            preload_adapters,
2018        )?))
2019    }
2020    fn is_gptx(&self, _: &str) -> Result<bool> {
2021        Ok(true)
2022    }
2023    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2024        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2025
2026        Ok(Box::new(cfg))
2027    }
2028}
2029
2030impl IsqModelLoader for Starcoder2Loader {
2031    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2032        Ok(vec![
2033            Regex::new(r"lm_head\.(weight|bias)$")?,
2034            // Attention
2035            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2036            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2037            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2038            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2039            // MLP
2040            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2041            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2042        ])
2043    }
2044    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2045        self.isq_layer_regexes(config)
2046    }
2047}
2048
2049impl DeviceMappedModelLoader for Starcoder2Loader {
2050    fn mapped_max_act_size_elems(
2051        &self,
2052        config: &str,
2053        params: &AutoDeviceMapParams,
2054    ) -> Result<usize> {
2055        let AutoDeviceMapParams::Text {
2056            max_seq_len,
2057            max_batch_size,
2058        } = params
2059        else {
2060            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2061        };
2062
2063        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2064
2065        Ok(
2066            max_batch_size
2067                * cfg.num_attention_heads
2068                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2069        )
2070    }
2071    fn non_mapped_max_act_size_elems(
2072        &self,
2073        _config: &str,
2074        _params: &AutoDeviceMapParams,
2075    ) -> Result<usize> {
2076        Ok(0)
2077    }
2078
2079    fn non_mapped_size_in_bytes(
2080        &self,
2081        config: &str,
2082        dtype: DType,
2083        weight_pack_factor: usize,
2084        _matformer_config: Option<&MatformerSliceConfig>,
2085    ) -> Result<usize> {
2086        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2087
2088        let elems = {
2089            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2090            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2091            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2092                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2093            } else {
2094                0
2095            };
2096            let norm = cfg.hidden_size + cfg.hidden_size;
2097            embed_tokens + lm_head + norm
2098        };
2099        Ok(elems * dtype.size_in_bytes())
2100    }
2101
2102    fn layer_sizes_in_bytes(
2103        &self,
2104        config: &str,
2105        dtype: DType,
2106        weight_pack_factor: usize,
2107        _matformer_config: Option<&MatformerSliceConfig>,
2108    ) -> Result<Vec<usize>> {
2109        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2110
2111        let per_layer_elems = {
2112            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2113            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2114
2115            let size_in = cfg.hidden_size;
2116            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2117            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2118            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2119            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2120            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2121            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2122
2123            let h_size = cfg.hidden_size;
2124            let i_size = cfg.intermediate_size;
2125            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2126            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2127
2128            input_layernorm
2129                + post_attention_layernorm
2130                + q_proj
2131                + k_proj
2132                + v_proj
2133                + o_proj
2134                + fc1
2135                + fc2
2136        };
2137        Ok(vec![
2138            per_layer_elems * dtype.size_in_bytes();
2139            cfg.num_hidden_layers
2140        ])
2141    }
2142
2143    fn num_layers(&self, config: &str) -> Result<usize> {
2144        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2145
2146        Ok(cfg.num_hidden_layers)
2147    }
2148
2149    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2150        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2151
2152        let cfg = ModelConfigMetadata {
2153            max_seq_len: cfg.max_position_embeddings,
2154            num_layers: cfg.num_hidden_layers,
2155            hidden_size: cfg.hidden_size,
2156            num_kv_heads: cfg.num_key_value_heads,
2157            num_attn_heads: cfg.num_attention_heads,
2158            sliding_window: cfg.sliding_window,
2159            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2160            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2161        };
2162
2163        Ok(Box::new(cfg))
2164    }
2165}
2166
2167// ======================== Phi3.5 MoE loader
2168
2169/// [`NormalLoader`] for a Phi 3.5 MoE model.
2170///
2171/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
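///
/// A minimal sketch of the two ISQ regex sets this loader exposes (hedged:
/// `config_json` is a hypothetical config string): the full set also covers the
/// attention projections, while the MoQE set restricts ISQ to the expert MLPs.
///
/// ```ignore
/// use crate::pipeline::isq::IsqModelLoader;
///
/// let loader = Phi3_5MoELoader;
/// let full = loader.isq_layer_regexes(config_json)?;       // lm_head + attention + experts
/// let moqe = loader.isq_layer_regexes_moqe(config_json)?;  // lm_head + experts only
/// assert!(full.len() > moqe.len());
/// ```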
2172pub struct Phi3_5MoELoader;
2173
2174impl NormalModelLoader for Phi3_5MoELoader {
2175    fn load(
2176        &self,
2177        config: &str,
2178        vb: ShardedVarBuilder,
2179        normal_loading_metadata: NormalLoadingMetadata,
2180        attention_mechanism: AttentionImplementation,
2181    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2182        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2183
2184        Ok(Box::new(models::phi3_5_moe::Model::new(
2185            &cfg,
2186            vb,
2187            self.is_gptx(config)?,
2188            normal_loading_metadata,
2189            attention_mechanism,
2190        )?))
2191    }
2192    fn load_xlora(
2193        &self,
2194        config: &str,
2195        vb: ShardedVarBuilder,
2196        lora_config: &[((String, String), LoraConfig)],
2197        xlora_config: Option<XLoraConfig>,
2198        xlora_ordering: Ordering,
2199        normal_loading_metadata: NormalLoadingMetadata,
2200        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2201    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
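        // Note: X-LoRA for this loader is built by parsing the dense Phi 3 config and
        // constructing the dense `XLoraPhi3` model; no MoE-specific X-LoRA implementation
        // is wired up here.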
2202        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2203
2204        Ok(Box::new(xlora_models::XLoraPhi3::new(
2205            &cfg,
2206            vb,
2207            lora_config,
2208            xlora_config,
2209            xlora_ordering,
2210            self.is_gptx(config)?,
2211            normal_loading_metadata,
2212            preload_adapters,
2213        )?))
2214    }
2215    fn is_gptx(&self, _: &str) -> Result<bool> {
2216        Ok(true)
2217    }
2218    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2219        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2220
2221        Ok(Box::new(cfg))
2222    }
2223}
2224
2225impl IsqModelLoader for Phi3_5MoELoader {
2226    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2227        Ok(vec![
2228            Regex::new(r"lm_head\.(weight|bias)$")?,
2229            // Attention
2230            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2231            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2232            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2233            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2234            // MLP
2235            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2236            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2237            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2238        ])
2239    }
2240    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2241        self.isq_layer_regexes(config)
2242    }
2243
2244    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2245        Ok(vec![
2246            Regex::new(r"lm_head\.(weight|bias)$")?,
2247            // MLP
2248            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2249            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2250            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2251        ])
2252    }
2253    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2254        self.isq_layer_regexes_moqe(config)
2255    }
2256}
2257
2258impl DeviceMappedModelLoader for Phi3_5MoELoader {
2259    fn mapped_max_act_size_elems(
2260        &self,
2261        config: &str,
2262        params: &AutoDeviceMapParams,
2263    ) -> Result<usize> {
2264        let AutoDeviceMapParams::Text {
2265            max_seq_len,
2266            max_batch_size,
2267        } = params
2268        else {
2269            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2270        };
2271
2272        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2273
2274        Ok(
2275            max_batch_size
2276                * cfg.num_attention_heads
2277                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2278        )
2279    }
2280    fn non_mapped_max_act_size_elems(
2281        &self,
2282        _config: &str,
2283        _params: &AutoDeviceMapParams,
2284    ) -> Result<usize> {
2285        Ok(0)
2286    }
2287
2288    fn non_mapped_size_in_bytes(
2289        &self,
2290        config: &str,
2291        dtype: DType,
2292        weight_pack_factor: usize,
2293        _matformer_config: Option<&MatformerSliceConfig>,
2294    ) -> Result<usize> {
2295        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2296
2297        let elems = {
2298            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2299            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2300            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2301                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2302            } else {
2303                0
2304            };
2305            let norm = cfg.hidden_size;
2306            embed_tokens + lm_head + norm
2307        };
2308        Ok(elems * dtype.size_in_bytes())
2309    }
2310
2311    fn layer_sizes_in_bytes(
2312        &self,
2313        config: &str,
2314        dtype: DType,
2315        weight_pack_factor: usize,
2316        _matformer_config: Option<&MatformerSliceConfig>,
2317    ) -> Result<Vec<usize>> {
2318        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2319
2320        let per_layer_elems = {
2321            let input_layernorm = cfg.hidden_size;
2322            let post_attention_layernorm = cfg.hidden_size;
2323
2324            let size_in = cfg.hidden_size;
2325            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2326            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2327            let q_proj =
2328                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2329            let k_proj =
2330                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2331            let v_proj =
2332                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2333            let o_proj =
2334                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2335
2336            let moe_block = {
2337                let gate = cfg.hidden_size * cfg.num_local_experts;
2338                // Assume the expert weights are packed/quantized with the same weight pack factor
2339                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2340                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2341                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2342                gate + cfg.num_local_experts * w1
2343                    + cfg.num_local_experts * w2
2344                    + cfg.num_local_experts * w3
2345            };
2346
2347            input_layernorm
2348                + post_attention_layernorm
2349                + q_proj
2350                + k_proj
2351                + v_proj
2352                + o_proj
2353                + moe_block
2354        };
2355        Ok(vec![
2356            per_layer_elems * dtype.size_in_bytes();
2357            cfg.num_hidden_layers
2358        ])
2359    }
2360
2361    fn num_layers(&self, config: &str) -> Result<usize> {
2362        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2363
2364        Ok(cfg.num_hidden_layers)
2365    }
2366
2367    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2368        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2369
2370        let cfg = ModelConfigMetadata {
2371            max_seq_len: cfg.max_position_embeddings,
2372            num_layers: cfg.num_hidden_layers,
2373            hidden_size: cfg.hidden_size,
2374            num_kv_heads: cfg.num_key_value_heads,
2375            num_attn_heads: cfg.num_attention_heads,
2376            sliding_window: cfg.sliding_window,
2377            k_head_dim: cfg.head_dim(),
2378            v_head_dim: cfg.head_dim(),
2379        };
2380
2381        Ok(Box::new(cfg))
2382    }
2383}
2384
2385/// [`NormalLoader`] for a DeepSeekV2 model.
2386///
2387/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
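///
/// The ISQ regexes below are config-dependent: when `q_lora_rank` is set the query
/// projection is split into `q_a_proj`/`q_b_proj`, otherwise a single `q_proj` is
/// matched, and expert regexes are emitted per MoE layer. A rough sketch (hedged:
/// `config_json` is a hypothetical DeepSeek V2 config string):
///
/// ```ignore
/// use crate::pipeline::isq::IsqModelLoader;
///
/// let loader = DeepSeekV2Loader;
/// // With `"q_lora_rank": 1536` present, expect q_a_proj/q_b_proj patterns;
/// // with it absent, expect a single q_proj pattern instead.
/// let regexes = loader.isq_layer_regexes(config_json)?;
/// ```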
2388pub struct DeepSeekV2Loader;
2389
2390impl NormalModelLoader for DeepSeekV2Loader {
2391    fn load(
2392        &self,
2393        config: &str,
2394        vb: ShardedVarBuilder,
2395        normal_loading_metadata: NormalLoadingMetadata,
2396        attention_mechanism: AttentionImplementation,
2397    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2398        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2399
2400        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2401            &cfg,
2402            vb,
2403            self.is_gptx(config)?,
2404            normal_loading_metadata,
2405            attention_mechanism,
2406        )?))
2407    }
2408    fn load_xlora(
2409        &self,
2410        _config: &str,
2411        _vb: ShardedVarBuilder,
2412        _lora_config: &[((String, String), LoraConfig)],
2413        _xlora_config: Option<XLoraConfig>,
2414        _xlora_ordering: Ordering,
2415        _normal_loading_metadata: NormalLoadingMetadata,
2416        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2417    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2418        todo!()
2419    }
2420    fn is_gptx(&self, _: &str) -> Result<bool> {
2421        Ok(true)
2422    }
2423    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2424        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2425        Ok(Box::new(cfg))
2426    }
2427}
2428
2429impl IsqModelLoader for DeepSeekV2Loader {
2430    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2431        let mut data = vec![
2432            Regex::new(r"lm_head\.(weight|bias)$")?,
2433            // Attention
2434            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2435            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2436            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2437        ];
2438        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2439        if cfg.q_lora_rank.is_some() {
2440            data.extend(vec![
2441                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2442                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2443            ]);
2444        } else {
2445            data.push(Regex::new(
2446                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2447            )?);
2448        }
2449        for layer_idx in 0..cfg.num_hidden_layers {
2450            if cfg.n_routed_experts.is_some()
2451                && layer_idx >= cfg.first_k_dense_replace
2452                && layer_idx % cfg.moe_layer_freq == 0
2453            {
2454                for i in 0..cfg.n_routed_experts.unwrap() {
2455                    data.extend(vec![
2456                        Regex::new(&format!(
2457                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2458                        ))?,
2459                        Regex::new(&format!(
2460                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2461                        ))?,
2462                        Regex::new(&format!(
2463                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2464                        ))?,
2465                    ]);
2466                }
2467                if cfg.n_shared_experts.is_some() {
2468                    data.extend(vec![
2469                        Regex::new(&format!(
2470                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2471                        ))?,
2472                        Regex::new(&format!(
2473                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2474                        ))?,
2475                        Regex::new(&format!(
2476                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2477                        ))?,
2478                    ]);
2479                }
2480            } else {
2481                data.extend(vec![
2482                    Regex::new(&format!(
2483                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2484                    ))?,
2485                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2486                    Regex::new(&format!(
2487                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2488                    ))?,
2489                ]);
2490            };
2491        }
2492        Ok(data)
2493    }
2494    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2495        self.isq_layer_regexes(config)
2496    }
2497
2498    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2499        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2500        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2501        for layer_idx in 0..cfg.num_hidden_layers {
2502            if cfg.n_routed_experts.is_some()
2503                && layer_idx >= cfg.first_k_dense_replace
2504                && layer_idx % cfg.moe_layer_freq == 0
2505            {
2506                for i in 0..cfg.n_routed_experts.unwrap() {
2507                    data.extend(vec![
2508                        Regex::new(&format!(
2509                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2510                        ))?,
2511                        Regex::new(&format!(
2512                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2513                        ))?,
2514                        Regex::new(&format!(
2515                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2516                        ))?,
2517                    ]);
2518                }
2519                if cfg.n_shared_experts.is_some() {
2520                    data.extend(vec![
2521                        Regex::new(&format!(
2522                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2523                        ))?,
2524                        Regex::new(&format!(
2525                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2526                        ))?,
2527                        Regex::new(&format!(
2528                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2529                        ))?,
2530                    ]);
2531                }
2532            } else {
2533                data.extend(vec![
2534                    Regex::new(&format!(
2535                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2536                    ))?,
2537                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2538                    Regex::new(&format!(
2539                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2540                    ))?,
2541                ]);
2542            };
2543        }
2544        Ok(data)
2545    }
2546    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2547        self.isq_layer_regexes_moqe(config)
2548    }
2549}
2550
2551impl DeviceMappedModelLoader for DeepSeekV2Loader {
2552    fn mapped_max_act_size_elems(
2553        &self,
2554        config: &str,
2555        params: &AutoDeviceMapParams,
2556    ) -> Result<usize> {
2557        let AutoDeviceMapParams::Text {
2558            max_seq_len,
2559            max_batch_size,
2560        } = params
2561        else {
2562            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2563        };
2564
2565        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2566
2567        Ok(
2568            max_batch_size
2569                * cfg.num_attention_heads
2570                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2571        )
2572    }
2573    fn non_mapped_max_act_size_elems(
2574        &self,
2575        _config: &str,
2576        _params: &AutoDeviceMapParams,
2577    ) -> Result<usize> {
2578        Ok(0)
2579    }
2580
2581    fn non_mapped_size_in_bytes(
2582        &self,
2583        config: &str,
2584        dtype: DType,
2585        weight_pack_factor: usize,
2586        _matformer_config: Option<&MatformerSliceConfig>,
2587    ) -> Result<usize> {
2588        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2589        let elems = {
2590            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2591            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2592            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2593                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2594            } else {
2595                0
2596            };
2597            let norm = cfg.hidden_size;
2598            embed_tokens + lm_head + norm
2599        };
2600        Ok(elems * dtype.size_in_bytes())
2601    }
2602
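    // Per-layer sizes are not uniform here: layers before `first_k_dense_replace` (or
    // where `layer_idx % moe_layer_freq != 0`) use a dense MLP, while the remaining
    // layers carry the routed experts, optional shared experts, and the router gate.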
2603    fn layer_sizes_in_bytes(
2604        &self,
2605        config: &str,
2606        dtype: DType,
2607        weight_pack_factor: usize,
2608        _matformer_config: Option<&MatformerSliceConfig>,
2609    ) -> Result<Vec<usize>> {
2610        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2611        let mut per_layer_elems = Vec::new();
2612
2613        for layer_idx in 0..cfg.num_hidden_layers {
2614            let input_layernorm = cfg.hidden_size;
2615            let post_attention_layernorm = cfg.hidden_size;
2616
2617            let q_proj = match cfg.q_lora_rank {
2618                Some(lora_rank) => {
2619                    let a = cfg.hidden_size * lora_rank;
2620                    let norm = lora_rank;
2621                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2622                    a + norm + b
2623                }
2624                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2625            };
2626            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2627                / weight_pack_factor
2628                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2629            let kv_a_layernorm = cfg.kv_lora_rank;
2630            let kv_b_proj = cfg.kv_lora_rank
2631                * cfg.num_attention_heads
2632                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2633                / weight_pack_factor;
2634            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2635                / weight_pack_factor
2636                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2637
2638            let moe_block = {
2639                let mut sum = 0;
2640                if cfg.n_routed_experts.is_some()
2641                    && layer_idx >= cfg.first_k_dense_replace
2642                    && layer_idx % cfg.moe_layer_freq == 0
2643                {
2644                    let h_size = cfg.hidden_size;
2645                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2646                        * cfg.n_routed_experts.unwrap();
2647                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2648                        * cfg.n_routed_experts.unwrap();
2649                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2650                        * cfg.n_routed_experts.unwrap();
2651                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2652                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2653                            / weight_pack_factor;
2654                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2655                            / weight_pack_factor;
2656                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2657                            / weight_pack_factor;
2658                        gate_proj + up_proj + down_proj
2659                    } else {
2660                        0
2661                    };
2662                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2663                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2664                } else {
2665                    let h_size = cfg.hidden_size;
2666                    let i_size = cfg.intermediate_size;
2667                    let gate_proj = h_size * i_size / weight_pack_factor;
2668                    let up_proj = h_size * i_size / weight_pack_factor;
2669                    let down_proj = i_size * h_size / weight_pack_factor;
2670                    sum += gate_proj + up_proj + down_proj;
2671                }
2672                sum
2673            };
2674
2675            per_layer_elems.push(
2676                input_layernorm
2677                    + post_attention_layernorm
2678                    + q_proj
2679                    + kv_a_layernorm
2680                    + kv_a_proj_with_mqa
2681                    + kv_b_proj
2682                    + o_proj
2683                    + moe_block,
2684            );
2685        }
2686
2687        Ok(per_layer_elems
2688            .into_iter()
2689            .map(|x| x * dtype.size_in_bytes())
2690            .collect())
2691    }
2692
2693    fn num_layers(&self, config: &str) -> Result<usize> {
2694        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2695        Ok(cfg.num_hidden_layers)
2696    }
2697
2698    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2699        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2700
2701        let cfg = ModelConfigMetadata {
2702            max_seq_len: cfg.max_position_embeddings,
2703            num_layers: cfg.num_hidden_layers,
2704            hidden_size: cfg.hidden_size,
2705            num_kv_heads: cfg.num_attention_heads,
2706            num_attn_heads: cfg.num_attention_heads,
2707            sliding_window: None,
2708            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2709            v_head_dim: cfg.v_head_dim,
2710        };
2711
2712        Ok(Box::new(cfg))
2713    }
2714}
2715
2716/// [`NormalLoader`] for a DeepSeekV3 model.
2717///
2718/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
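///
/// A minimal sketch of the size estimation below (hedged: `config_json` is a
/// hypothetical DeepSeek V3 config string). The returned per-layer sizes are not
/// uniform, since dense and MoE layers differ:
///
/// ```ignore
/// use candle_core::DType;
/// use super::DeviceMappedModelLoader;
///
/// let loader = DeepSeekV3Loader;
/// let layer_bytes = loader.layer_sizes_in_bytes(config_json, DType::F16, 1, None)?;
/// let mapped_total: usize = layer_bytes.iter().sum();
/// ```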
2719pub struct DeepSeekV3Loader;
2720
2721impl NormalModelLoader for DeepSeekV3Loader {
2722    fn load(
2723        &self,
2724        config: &str,
2725        vb: ShardedVarBuilder,
2726        normal_loading_metadata: NormalLoadingMetadata,
2727        attention_mechanism: AttentionImplementation,
2728    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2729        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2730        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2731            &cfg,
2732            vb,
2733            self.is_gptx(config)?,
2734            normal_loading_metadata,
2735            attention_mechanism,
2736        )?))
2737    }
2738    fn load_xlora(
2739        &self,
2740        _config: &str,
2741        _vb: ShardedVarBuilder,
2742        _lora_config: &[((String, String), LoraConfig)],
2743        _xlora_config: Option<XLoraConfig>,
2744        _xlora_ordering: Ordering,
2745        _normal_loading_metadata: NormalLoadingMetadata,
2746        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2747    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2748        todo!()
2749    }
2750    fn is_gptx(&self, _: &str) -> Result<bool> {
2751        Ok(true)
2752    }
2753    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2754        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2755        Ok(Box::new(cfg))
2756    }
2757}
2758
2759impl IsqModelLoader for DeepSeekV3Loader {
2760    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2761        let mut data = vec![
2762            Regex::new(r"lm_head\.(weight|bias)$")?,
2763            // Attention
2764            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2765            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2766            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2767        ];
2768        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2769        if cfg.q_lora_rank.is_some() {
2770            data.extend(vec![
2771                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2772                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2773            ]);
2774        } else {
2775            data.push(Regex::new(
2776                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2777            )?);
2778        }
2779        for layer_idx in 0..cfg.num_hidden_layers {
2780            if cfg.n_routed_experts.is_some()
2781                && layer_idx >= cfg.first_k_dense_replace
2782                && layer_idx % cfg.moe_layer_freq == 0
2783            {
2784                for i in 0..cfg.n_routed_experts.unwrap() {
2785                    data.extend(vec![
2786                        Regex::new(&format!(
2787                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2788                        ))?,
2789                        Regex::new(&format!(
2790                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2791                        ))?,
2792                        Regex::new(&format!(
2793                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2794                        ))?,
2795                    ]);
2796                }
2797                if cfg.n_shared_experts.is_some() {
2798                    data.extend(vec![
2799                        Regex::new(&format!(
2800                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2801                        ))?,
2802                        Regex::new(&format!(
2803                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2804                        ))?,
2805                        Regex::new(&format!(
2806                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2807                        ))?,
2808                    ]);
2809                }
2810            } else {
2811                data.extend(vec![
2812                    Regex::new(&format!(
2813                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2814                    ))?,
2815                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2816                    Regex::new(&format!(
2817                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2818                    ))?,
2819                ]);
2820            };
2821        }
2822        Ok(data)
2823    }
2824    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2825        self.isq_layer_regexes(config)
2826    }
2827
2828    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2829        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2830        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2831        for layer_idx in 0..cfg.num_hidden_layers {
2832            if cfg.n_routed_experts.is_some()
2833                && layer_idx >= cfg.first_k_dense_replace
2834                && layer_idx % cfg.moe_layer_freq == 0
2835            {
2836                for i in 0..cfg.n_routed_experts.unwrap() {
2837                    data.extend(vec![
2838                        Regex::new(&format!(
2839                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2840                        ))?,
2841                        Regex::new(&format!(
2842                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2843                        ))?,
2844                        Regex::new(&format!(
2845                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2846                        ))?,
2847                    ]);
2848                }
2849                if cfg.n_shared_experts.is_some() {
2850                    data.extend(vec![
2851                        Regex::new(&format!(
2852                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2853                        ))?,
2854                        Regex::new(&format!(
2855                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2856                        ))?,
2857                        Regex::new(&format!(
2858                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2859                        ))?,
2860                    ]);
2861                }
2862            } else {
2863                data.extend(vec![
2864                    Regex::new(&format!(
2865                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2866                    ))?,
2867                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2868                    Regex::new(&format!(
2869                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2870                    ))?,
2871                ]);
2872            };
2873        }
2874        Ok(data)
2875    }
2876    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2877        self.isq_layer_regexes_moqe(config)
2878    }
2879}
2880
2881impl DeviceMappedModelLoader for DeepSeekV3Loader {
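    // Estimated upper bound on mapped attention activations: batch size * attention heads
    // * (sequence length, capped at ATTENTION_CHUNK_SIZE) squared.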
2882    fn mapped_max_act_size_elems(
2883        &self,
2884        config: &str,
2885        params: &AutoDeviceMapParams,
2886    ) -> Result<usize> {
2887        let AutoDeviceMapParams::Text {
2888            max_seq_len,
2889            max_batch_size,
2890        } = params
2891        else {
2892            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2893        };
2894
2895        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2896
2897        Ok(
2898            max_batch_size
2899                * cfg.num_attention_heads
2900                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2901        )
2902    }
2903    fn non_mapped_max_act_size_elems(
2904        &self,
2905        _config: &str,
2906        _params: &AutoDeviceMapParams,
2907    ) -> Result<usize> {
2908        Ok(0)
2909    }
2910
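    // Weights that are not device-mapped: token embeddings, the final norm, and lm_head
    // (counted separately unless the embeddings are tied and no weight packing is applied).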
2911    fn non_mapped_size_in_bytes(
2912        &self,
2913        config: &str,
2914        dtype: DType,
2915        weight_pack_factor: usize,
2916        _matformer_config: Option<&MatformerSliceConfig>,
2917    ) -> Result<usize> {
2918        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2919        let elems = {
2920            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2921            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2922            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2923                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2924            } else {
2925                0
2926            };
2927            let norm = cfg.hidden_size;
2928            embed_tokens + lm_head + norm
2929        };
2930        Ok(elems * dtype.size_in_bytes())
2931    }
2932
2933    fn layer_sizes_in_bytes(
2934        &self,
2935        config: &str,
2936        dtype: DType,
2937        weight_pack_factor: usize,
2938        _matformer_config: Option<&MatformerSliceConfig>,
2939    ) -> Result<Vec<usize>> {
2940        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2941        let mut per_layer_elems = Vec::new();
2942
2943        for layer_idx in 0..cfg.num_hidden_layers {
2944            let input_layernorm = cfg.hidden_size;
2945            let post_attention_layernorm = cfg.hidden_size;
2946
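            // Multi-head latent attention: with q_lora_rank set, the query projection is
            // factored into a down-projection, a norm, and an up-projection rather than a
            // single dense matrix.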
2947            let q_proj = match cfg.q_lora_rank {
2948                Some(lora_rank) => {
2949                    let a = cfg.hidden_size * lora_rank;
2950                    let norm = lora_rank;
2951                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2952                    a + norm + b
2953                }
2954                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2955            };
2956            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2957                / weight_pack_factor
2958                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2959            let kv_a_layernorm = cfg.kv_lora_rank;
2960            let kv_b_proj = cfg.kv_lora_rank
2961                * cfg.num_attention_heads
2962                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2963                / weight_pack_factor;
2964            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2965                / weight_pack_factor
2966                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2967
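            // Routed-expert MLPs start after the first `first_k_dense_replace` layers and
            // then occur every `moe_layer_freq` layers; all other layers fall back to a
            // dense MLP of size `intermediate_size`.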
2968            let moe_block = {
2969                let mut sum = 0;
2970                if cfg.n_routed_experts.is_some()
2971                    && layer_idx >= cfg.first_k_dense_replace
2972                    && layer_idx % cfg.moe_layer_freq == 0
2973                {
2974                    let h_size = cfg.hidden_size;
2975                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2976                        * cfg.n_routed_experts.unwrap();
2977                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2978                        * cfg.n_routed_experts.unwrap();
2979                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2980                        * cfg.n_routed_experts.unwrap();
2981                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2982                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2983                            / weight_pack_factor;
2984                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2985                            / weight_pack_factor;
2986                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2987                            / weight_pack_factor;
2988                        gate_proj + up_proj + down_proj
2989                    } else {
2990                        0
2991                    };
2992                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2993                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2994                } else {
2995                    let h_size = cfg.hidden_size;
2996                    let i_size = cfg.intermediate_size;
2997                    let gate_proj = h_size * i_size / weight_pack_factor;
2998                    let up_proj = h_size * i_size / weight_pack_factor;
2999                    let down_proj = i_size * h_size / weight_pack_factor;
3000                    sum += gate_proj + up_proj + down_proj;
3001                }
3002                sum
3003            };
3004
3005            per_layer_elems.push(
3006                input_layernorm
3007                    + post_attention_layernorm
3008                    + q_proj
3009                    + kv_a_layernorm
3010                    + kv_a_proj_with_mqa
3011                    + kv_b_proj
3012                    + o_proj
3013                    + moe_block,
3014            );
3015        }
3016
3017        Ok(per_layer_elems
3018            .into_iter()
3019            .map(|x| x * dtype.size_in_bytes())
3020            .collect())
3021    }
3022
3023    fn num_layers(&self, config: &str) -> Result<usize> {
3024        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3025        Ok(cfg.num_hidden_layers)
3026    }
3027
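    // The effective key head dimension is the concatenation of the RoPE and non-RoPE
    // components (qk_rope_head_dim + qk_nope_head_dim).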
3028    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3029        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3030
3031        let cfg = ModelConfigMetadata {
3032            max_seq_len: cfg.max_position_embeddings,
3033            num_layers: cfg.num_hidden_layers,
3034            hidden_size: cfg.hidden_size,
3035            num_kv_heads: cfg.num_attention_heads,
3036            num_attn_heads: cfg.num_attention_heads,
3037            sliding_window: None,
3038            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3039            v_head_dim: cfg.v_head_dim,
3040        };
3041
3042        Ok(Box::new(cfg))
3043    }
3044}
3045
3046/// [`NormalLoader`] for a Qwen 3 model.
3047///
3048/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
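///
/// A minimal usage sketch (illustrative only: it assumes a config JSON string, a
/// `ShardedVarBuilder`, and `NormalLoadingMetadata` are already in hand, and the
/// attention implementation shown is just one possible choice):
///
/// ```ignore
/// let loader = Qwen3Loader;
/// let model = loader.load(
///     &config_json,
///     vb,
///     normal_loading_metadata,
///     AttentionImplementation::Eager,
/// )?;
/// ```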
3049pub struct Qwen3Loader;
3050
3051impl NormalModelLoader for Qwen3Loader {
3052    fn load(
3053        &self,
3054        config: &str,
3055        vb: ShardedVarBuilder,
3056        normal_loading_metadata: NormalLoadingMetadata,
3057        attention_mechanism: AttentionImplementation,
3058    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3059        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3060
3061        Ok(Box::new(models::qwen3::Model::new(
3062            &cfg,
3063            vb,
3064            self.is_gptx(config)?,
3065            normal_loading_metadata,
3066            attention_mechanism,
3067        )?))
3068    }
3069    fn load_xlora(
3070        &self,
3071        _config: &str,
3072        _vb: ShardedVarBuilder,
3073        _lora_config: &[((String, String), LoraConfig)],
3074        _xlora_config: Option<XLoraConfig>,
3075        _xlora_ordering: Ordering,
3076        _normal_loading_metadata: NormalLoadingMetadata,
3077        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3078    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3079        todo!()
3080    }
3081    fn is_gptx(&self, _: &str) -> Result<bool> {
3082        Ok(true)
3083    }
3084    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3085        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3086
3087        Ok(Box::new(cfg))
3088    }
3089}
3090
3091impl IsqModelLoader for Qwen3Loader {
3092    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3093        Ok(vec![
3094            Regex::new(r"lm_head\.(weight|bias)$")?,
3095            // Attention
3096            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3097            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3098            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3099            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3100            // MLP
3101            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3102            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3103            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3104        ])
3105    }
3106    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3107        self.isq_layer_regexes(config)
3108    }
3109}
3110
3111impl DeviceMappedModelLoader for Qwen3Loader {
3112    fn mapped_max_act_size_elems(
3113        &self,
3114        config: &str,
3115        params: &AutoDeviceMapParams,
3116    ) -> Result<usize> {
3117        let AutoDeviceMapParams::Text {
3118            max_seq_len,
3119            max_batch_size,
3120        } = params
3121        else {
3122            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3123        };
3124
3125        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3126
3127        Ok(
3128            max_batch_size
3129                * cfg.num_attention_heads
3130                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3131        )
3132    }
3133    fn non_mapped_max_act_size_elems(
3134        &self,
3135        _config: &str,
3136        _params: &AutoDeviceMapParams,
3137    ) -> Result<usize> {
3138        Ok(0)
3139    }
3140
3141    fn non_mapped_size_in_bytes(
3142        &self,
3143        config: &str,
3144        dtype: DType,
3145        weight_pack_factor: usize,
3146        _matformer_config: Option<&MatformerSliceConfig>,
3147    ) -> Result<usize> {
3148        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3149        let elems = {
3150            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3151            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3152            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3153                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3154            } else {
3155                0
3156            };
3157            let norm = cfg.hidden_size;
3158            embed_tokens + lm_head + norm
3159        };
3160        Ok(elems * dtype.size_in_bytes())
3161    }
3162
3163    fn layer_sizes_in_bytes(
3164        &self,
3165        config: &str,
3166        dtype: DType,
3167        weight_pack_factor: usize,
3168        _matformer_config: Option<&MatformerSliceConfig>,
3169    ) -> Result<Vec<usize>> {
3170        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3171        let per_layer_elems = {
3172            let input_layernorm = cfg.hidden_size;
3173            let post_attention_layernorm = cfg.hidden_size;
3174
3175            let size_in = cfg.hidden_size;
3176            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3177            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3178            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3179            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3180            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3181            let o_proj = size_q * size_in / weight_pack_factor;
3182
3183            let h_size = cfg.hidden_size;
3184            let i_size = cfg.intermediate_size;
3185            let gate_proj = h_size * i_size / weight_pack_factor;
3186            let up_proj = h_size * i_size / weight_pack_factor;
3187            let down_proj = i_size * h_size / weight_pack_factor;
3188
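            // Qwen 3 applies RMSNorm to the per-head query and key states, adding one
            // `head_dim`-sized weight each for q_norm and k_norm.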
3189            let q_norm = cfg.head_dim();
3190            let k_norm = cfg.head_dim();
3191
3192            input_layernorm
3193                + post_attention_layernorm
3194                + q_proj
3195                + k_proj
3196                + v_proj
3197                + o_proj
3198                + gate_proj
3199                + up_proj
3200                + down_proj
3201                + q_norm
3202                + k_norm
3203        };
3204        Ok(vec![
3205            per_layer_elems * dtype.size_in_bytes();
3206            cfg.num_hidden_layers
3207        ])
3208    }
3209
3210    fn num_layers(&self, config: &str) -> Result<usize> {
3211        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3212        Ok(cfg.num_hidden_layers)
3213    }
3214
3215    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3216        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3217
3218        let cfg = ModelConfigMetadata {
3219            max_seq_len: cfg.max_position_embeddings,
3220            num_layers: cfg.num_hidden_layers,
3221            hidden_size: cfg.hidden_size,
3222            num_kv_heads: cfg.num_key_value_heads,
3223            num_attn_heads: cfg.num_attention_heads,
3224            sliding_window: cfg.sliding_window,
3225            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3226            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3227        };
3228
3229        Ok(Box::new(cfg))
3230    }
3231}
3232
3233/// [`NormalLoader`] for a GLM 4 model.
3234///
3235/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3236pub struct GLM4Loader;
3237
3238impl NormalModelLoader for GLM4Loader {
3239    fn load(
3240        &self,
3241        config: &str,
3242        vb: ShardedVarBuilder,
3243        normal_loading_metadata: NormalLoadingMetadata,
3244        attention_mechanism: AttentionImplementation,
3245    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3246        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3247
3248        Ok(Box::new(models::glm4::Model::new(
3249            &cfg,
3250            vb,
3251            self.is_gptx(config)?,
3252            normal_loading_metadata,
3253            attention_mechanism,
3254        )?))
3255    }
3256    fn load_xlora(
3257        &self,
3258        _config: &str,
3259        _vb: ShardedVarBuilder,
3260        _lora_config: &[((String, String), LoraConfig)],
3261        _xlora_config: Option<XLoraConfig>,
3262        _xlora_ordering: Ordering,
3263        _normal_loading_metadata: NormalLoadingMetadata,
3264        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3265    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3266        todo!()
3267    }
3268    fn is_gptx(&self, _: &str) -> Result<bool> {
3269        Ok(true)
3270    }
3271    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3272        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3273
3274        Ok(Box::new(cfg))
3275    }
3276}
3277
3278impl IsqModelLoader for GLM4Loader {
3279    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3280        Ok(vec![
3281            Regex::new(r"lm_head\.(weight|bias)$")?,
3282            // Attention
3283            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3284            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3285            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3286            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3287            // MLP
3288            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3289            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3290            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3291        ])
3292    }
3293    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3294        self.isq_layer_regexes(config)
3295    }
3296}
3297
3298impl DeviceMappedModelLoader for GLM4Loader {
3299    fn mapped_max_act_size_elems(
3300        &self,
3301        config: &str,
3302        params: &AutoDeviceMapParams,
3303    ) -> Result<usize> {
3304        let AutoDeviceMapParams::Text {
3305            max_seq_len,
3306            max_batch_size,
3307        } = params
3308        else {
3309            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3310        };
3311
3312        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3313
3314        Ok(
3315            max_batch_size
3316                * cfg.num_attention_heads
3317                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3318        )
3319    }
3320    fn non_mapped_max_act_size_elems(
3321        &self,
3322        _config: &str,
3323        _params: &AutoDeviceMapParams,
3324    ) -> Result<usize> {
3325        Ok(0)
3326    }
3327
3328    fn non_mapped_size_in_bytes(
3329        &self,
3330        config: &str,
3331        dtype: DType,
3332        weight_pack_factor: usize,
3333        _matformer_config: Option<&MatformerSliceConfig>,
3334    ) -> Result<usize> {
3335        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3336        let elems = {
3337            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3338            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3339            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3340                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3341            } else {
3342                0
3343            };
3344            let norm = cfg.hidden_size;
3345            embed_tokens + lm_head + norm
3346        };
3347        Ok(elems * dtype.size_in_bytes())
3348    }
3349
3350    fn layer_sizes_in_bytes(
3351        &self,
3352        config: &str,
3353        dtype: DType,
3354        weight_pack_factor: usize,
3355        _matformer_config: Option<&MatformerSliceConfig>,
3356    ) -> Result<Vec<usize>> {
3357        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3358        let per_layer_elems = {
3359            let input_layernorm = cfg.hidden_size;
3360            let post_attention_layernorm = cfg.hidden_size * 3; // includes post_self_attn_layernorm and post_mlp_layernorm
3361
3362            let size_in = cfg.hidden_size;
3363            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3364            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3365            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3366            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3367            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3368            let o_proj = size_q * size_in / weight_pack_factor;
3369
3370            let h_size = cfg.hidden_size;
3371            let i_size = cfg.intermediate_size;
3372            let gate_proj = h_size * i_size / weight_pack_factor;
3373            let up_proj = h_size * i_size / weight_pack_factor;
3374            let down_proj = i_size * h_size / weight_pack_factor;
3375
3376            input_layernorm
3377                + post_attention_layernorm
3378                + q_proj
3379                + k_proj
3380                + v_proj
3381                + o_proj
3382                + gate_proj
3383                + up_proj
3384                + down_proj
3385        };
3386        Ok(vec![
3387            per_layer_elems * dtype.size_in_bytes();
3388            cfg.num_hidden_layers
3389        ])
3390    }
3391
3392    fn num_layers(&self, config: &str) -> Result<usize> {
3393        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3394        Ok(cfg.num_hidden_layers)
3395    }
3396
3397    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3398        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3399
3400        let cfg = ModelConfigMetadata {
3401            max_seq_len: cfg.max_position_embeddings,
3402            num_layers: cfg.num_hidden_layers,
3403            hidden_size: cfg.hidden_size,
3404            num_kv_heads: cfg.num_key_value_heads,
3405            num_attn_heads: cfg.num_attention_heads,
3406            sliding_window: cfg.sliding_window,
3407            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3408            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3409        };
3410
3411        Ok(Box::new(cfg))
3412    }
3413}
3414
3415/// [`NormalLoader`] for a Qwen 3 MoE model.
3416///
3417/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3418pub struct Qwen3MoELoader;
3419
3420impl NormalModelLoader for Qwen3MoELoader {
3421    fn load(
3422        &self,
3423        config: &str,
3424        vb: ShardedVarBuilder,
3425        normal_loading_metadata: NormalLoadingMetadata,
3426        attention_mechanism: AttentionImplementation,
3427    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3428        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3429
3430        Ok(Box::new(models::qwen3_moe::Model::new(
3431            &cfg,
3432            vb,
3433            self.is_gptx(config)?,
3434            normal_loading_metadata,
3435            attention_mechanism,
3436        )?))
3437    }
3438    fn load_xlora(
3439        &self,
3440        _config: &str,
3441        _vb: ShardedVarBuilder,
3442        _lora_config: &[((String, String), LoraConfig)],
3443        _xlora_config: Option<XLoraConfig>,
3444        _xlora_ordering: Ordering,
3445        _normal_loading_metadata: NormalLoadingMetadata,
3446        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3447    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3448        todo!()
3449    }
3450    fn is_gptx(&self, _: &str) -> Result<bool> {
3451        Ok(true)
3452    }
3453    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3454        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3455
3456        Ok(Box::new(cfg))
3457    }
3458}
3459
3460impl IsqModelLoader for Qwen3MoELoader {
3461    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3462        Ok(vec![
3463            Regex::new(r"lm_head\.(weight|bias)$")?,
3464            // Attention
3465            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3466            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3467            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3468            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3469            // MLP
3470            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3471            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3472            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3473            // MLP MoE
3474            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3475            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3476            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3477        ])
3478    }
3479    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3480        self.isq_layer_regexes(config)
3481    }
3482    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3483        self.isq_layer_regexes_moqe(config)
3484    }
3485}
3486
3487impl DeviceMappedModelLoader for Qwen3MoELoader {
3488    fn mapped_max_act_size_elems(
3489        &self,
3490        config: &str,
3491        params: &AutoDeviceMapParams,
3492    ) -> Result<usize> {
3493        let AutoDeviceMapParams::Text {
3494            max_seq_len,
3495            max_batch_size,
3496        } = params
3497        else {
3498            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3499        };
3500
3501        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3502
3503        Ok(
3504            max_batch_size
3505                * cfg.num_attention_heads
3506                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3507        )
3508    }
3509    fn non_mapped_max_act_size_elems(
3510        &self,
3511        _config: &str,
3512        _params: &AutoDeviceMapParams,
3513    ) -> Result<usize> {
3514        Ok(0)
3515    }
3516
3517    fn non_mapped_size_in_bytes(
3518        &self,
3519        config: &str,
3520        dtype: DType,
3521        weight_pack_factor: usize,
3522        _matformer_config: Option<&MatformerSliceConfig>,
3523    ) -> Result<usize> {
3524        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3525        let elems = {
3526            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3527            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3528            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3529                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3530            } else {
3531                0
3532            };
3533            let norm = cfg.hidden_size;
3534            embed_tokens + lm_head + norm
3535        };
3536        Ok(elems * dtype.size_in_bytes())
3537    }
3538
3539    fn layer_sizes_in_bytes(
3540        &self,
3541        config: &str,
3542        dtype: DType,
3543        weight_pack_factor: usize,
3544        _matformer_config: Option<&MatformerSliceConfig>,
3545    ) -> Result<Vec<usize>> {
3546        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3547
3548        let mut layer_sizes_in_bytes = Vec::new();
3549        for layer_idx in 0..cfg.num_hidden_layers {
3550            let input_layernorm = cfg.hidden_size;
3551            let post_attention_layernorm = cfg.hidden_size;
3552
3553            let size_in = cfg.hidden_size;
3554            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3555            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3556            let q_proj = size_in * size_q / weight_pack_factor;
3557            let k_proj = size_in * size_kv / weight_pack_factor;
3558            let v_proj = size_in * size_kv / weight_pack_factor;
3559            let o_proj = size_q * size_in / weight_pack_factor;
3560
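            // A layer is sparse (MoE) when it is not listed in `mlp_only_layers`,
            // `num_experts` is non-zero, and it falls on a `decoder_sparse_step` boundary;
            // otherwise it uses a dense MLP of size `intermediate_size`.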
3561            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3562                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3563            {
3564                let gate_size = cfg.hidden_size * cfg.num_experts;
3565                let expert_size = {
3566                    let h_size = cfg.hidden_size;
3567                    let i_size = cfg.moe_intermediate_size;
3568                    let gate_proj = h_size * i_size / weight_pack_factor;
3569                    let up_proj = h_size * i_size / weight_pack_factor;
3570                    let down_proj = i_size * h_size / weight_pack_factor;
3571                    gate_proj + up_proj + down_proj
3572                };
3573                expert_size * cfg.num_experts + gate_size
3574            } else {
3575                let h_size = cfg.hidden_size;
3576                let i_size = cfg.intermediate_size;
3577                let gate_proj = h_size * i_size / weight_pack_factor;
3578                let up_proj = h_size * i_size / weight_pack_factor;
3579                let down_proj = i_size * h_size / weight_pack_factor;
3580                gate_proj + up_proj + down_proj
3581            };
3582
3583            let q_norm = cfg.head_dim();
3584            let k_norm = cfg.head_dim();
3585
3586            let size_elems = input_layernorm
3587                + post_attention_layernorm
3588                + q_proj
3589                + k_proj
3590                + v_proj
3591                + o_proj
3592                + mlp_size
3593                + q_norm
3594                + k_norm;
3595
3596            let size_in_bytes = size_elems * dtype.size_in_bytes();
3597            layer_sizes_in_bytes.push(size_in_bytes);
3598        }
3599
3600        Ok(layer_sizes_in_bytes)
3601    }
3602
3603    fn num_layers(&self, config: &str) -> Result<usize> {
3604        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3605        Ok(cfg.num_hidden_layers)
3606    }
3607
3608    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3609        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3610
3611        let cfg = ModelConfigMetadata {
3612            max_seq_len: cfg.max_position_embeddings,
3613            num_layers: cfg.num_hidden_layers,
3614            hidden_size: cfg.hidden_size,
3615            num_kv_heads: cfg.num_key_value_heads,
3616            num_attn_heads: cfg.num_attention_heads,
3617            sliding_window: cfg.sliding_window,
3618            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3619            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3620        };
3621
3622        Ok(Box::new(cfg))
3623    }
3624}
3625
3626// ======================== SmolLm3 loader
3627
3628/// [`NormalLoader`] for a SmolLm3 model.
3629///
3630/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3631pub struct SmolLm3Loader;
3632
3633impl NormalModelLoader for SmolLm3Loader {
3634    fn load(
3635        &self,
3636        config: &str,
3637        vb: ShardedVarBuilder,
3638        normal_loading_metadata: NormalLoadingMetadata,
3639        attention_mechanism: AttentionImplementation,
3640    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3641        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3642
3643        Ok(Box::new(models::smollm3::SmolLm3::new(
3644            &cfg,
3645            vb,
3646            self.is_gptx(config)?,
3647            normal_loading_metadata,
3648            attention_mechanism,
3649        )?))
3650    }
3651    fn load_xlora(
3652        &self,
3653        _config: &str,
3654        _vb: ShardedVarBuilder,
3655        _lora_config: &[((String, String), LoraConfig)],
3656        _xlora_config: Option<XLoraConfig>,
3657        _xlora_ordering: Ordering,
3658        _normal_loading_metadata: NormalLoadingMetadata,
3659        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3660    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3661        todo!()
3662    }
3663    fn is_gptx(&self, _: &str) -> Result<bool> {
3664        Ok(true)
3665    }
3666    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3667        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3668        Ok(Box::new(cfg))
3669    }
3670}
3671
3672impl IsqModelLoader for SmolLm3Loader {
3673    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3674        Ok(vec![
3675            Regex::new(r"lm_head\.(weight|bias)$")?,
3676            // Attention
3677            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3678            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3679            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3680            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3681            // MLP
3682            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3683            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3684            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3685        ])
3686    }
3687    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3688        self.isq_layer_regexes(config)
3689    }
3690}
3691
3692impl DeviceMappedModelLoader for SmolLm3Loader {
3693    fn mapped_max_act_size_elems(
3694        &self,
3695        config: &str,
3696        params: &AutoDeviceMapParams,
3697    ) -> Result<usize> {
3698        let AutoDeviceMapParams::Text {
3699            max_seq_len,
3700            max_batch_size,
3701        } = params
3702        else {
3703            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3704        };
3705
3706        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3707
3708        Ok(
3709            max_batch_size
3710                * cfg.num_attention_heads
3711                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3712        )
3713    }
3714    fn non_mapped_max_act_size_elems(
3715        &self,
3716        _config: &str,
3717        _params: &AutoDeviceMapParams,
3718    ) -> Result<usize> {
3719        Ok(0)
3720    }
3721
3722    fn non_mapped_size_in_bytes(
3723        &self,
3724        config: &str,
3725        dtype: DType,
3726        weight_pack_factor: usize,
3727        _matformer_config: Option<&MatformerSliceConfig>,
3728    ) -> Result<usize> {
3729        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3730
3731        let elems = {
3732            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3733            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3734            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3735                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3736            } else {
3737                0
3738            };
3739            let norm = cfg.hidden_size;
3740            embed_tokens + lm_head + norm
3741        };
3742        Ok(elems * dtype.size_in_bytes())
3743    }
3744
3745    fn layer_sizes_in_bytes(
3746        &self,
3747        config: &str,
3748        dtype: DType,
3749        weight_pack_factor: usize,
3750        _matformer_config: Option<&MatformerSliceConfig>,
3751    ) -> Result<Vec<usize>> {
3752        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3753
3754        let per_layer_elems = {
3755            let input_layernorm = cfg.hidden_size;
3756            let post_attention_layernorm = cfg.hidden_size;
3757
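            // The head dimension is derived as hidden_size / num_attention_heads for both
            // the query and key/value projection sizes.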
3758            let size_in = cfg.hidden_size;
3759            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3760            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3761            let q_proj = size_in * size_q / weight_pack_factor;
3762            let k_proj = size_in * size_kv / weight_pack_factor;
3763            let v_proj = size_in * size_kv / weight_pack_factor;
3764            let o_proj = size_q * size_in / weight_pack_factor;
3765
3766            let h_size = cfg.hidden_size;
3767            let i_size = cfg.intermediate_size;
3768            let gate_proj = h_size * i_size / weight_pack_factor;
3769            let up_proj = h_size * i_size / weight_pack_factor;
3770            let down_proj = i_size * h_size / weight_pack_factor;
3771
3772            input_layernorm
3773                + post_attention_layernorm
3774                + q_proj
3775                + k_proj
3776                + v_proj
3777                + o_proj
3778                + gate_proj
3779                + up_proj
3780                + down_proj
3781        };
3782        Ok(vec![
3783            per_layer_elems * dtype.size_in_bytes();
3784            cfg.num_hidden_layers
3785        ])
3786    }
3787
3788    fn num_layers(&self, config: &str) -> Result<usize> {
3789        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3790
3791        Ok(cfg.num_hidden_layers)
3792    }
3793    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3794        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3795
3796        let cfg = ModelConfigMetadata {
3797            max_seq_len: cfg.max_position_embeddings,
3798            num_layers: cfg.num_hidden_layers,
3799            hidden_size: cfg.hidden_size,
3800            num_kv_heads: cfg.num_key_value_heads,
3801            num_attn_heads: cfg.num_attention_heads,
3802            sliding_window: None,
3803            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3804            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3805        };
3806
3807        Ok(Box::new(cfg))
3808    }
3809}