mistralrs_core/pipeline/loaders/
normal_loaders.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    str::FromStr,
5    sync::Arc,
6};
7
8use crate::matformer::MatformerSliceConfig;
9
10use crate::{
11    amoe::AnyMoeBaseModelMixin,
12    device_map::DeviceMapper,
13    lora::{LoraConfig, Ordering},
14    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
15    pipeline::{
16        isq::IsqModelLoader,
17        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
18        EitherCache, IsqModel,
19    },
20    utils::varbuilder_utils::DeviceForLoadTensor,
21    xlora_models::NonGranularState,
22};
23use anyhow::Result;
24use candle_core::{DType, Device, Tensor};
25use mistralrs_quant::log::once_log_info;
26
27use indicatif::MultiProgress;
28use mistralrs_quant::ShardedVarBuilder;
29#[cfg(feature = "pyo3_macros")]
30use pyo3::pyclass;
31
32use regex::Regex;
33use serde::Deserialize;
34
35use crate::{
36    models,
37    xlora_models::{self, XLoraConfig},
38};
39
40use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
41
/// A model instance produced by a [`NormalModelLoader`] and driven by the normal pipeline.
///
/// Implementors must also support ISQ ([`IsqModel`]) and AnyMoE ([`AnyMoeBaseModelMixin`]).
pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    /// Runs a forward pass over `input_ids` and returns the resulting tensor.
    ///
    /// NOTE(review): `seqlen_offsets`, `context_lens` and `position_ids` appear to carry
    /// per-sequence position information for the batch, and `metadata` presumably holds
    /// per-layer (key, value) cache tensors plus the [`PagedAttentionInputMetadata`] when
    /// paged attention is active — confirm against the implementations.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    /// Forward pass used by X-LoRA models.
    ///
    /// NOTE(review): the `*_full` arguments presumably describe the full input sequence
    /// alongside the incremental one, and `no_kv_cache` presumably disables KV-cache use;
    /// `non_granular_state` is passed through from the caller — confirm.
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    /// Whether this model is an X-LoRA model.
    fn is_xlora(&self) -> bool;
    /// The [`Device`] reported by the model.
    fn device(&self) -> &Device;
    /// Shared access to the model's cache.
    fn cache(&self) -> &EitherCache;
    /// Mutable access to the model's cache.
    fn cache_mut(&mut self) -> &mut EitherCache;
    /// Maximum sequence length reported by the model.
    fn max_seq_len(&self) -> usize;
    /// Configuration metadata (layer count, hidden size, head counts, head dims, ...).
    fn config(&self) -> &ModelConfigMetadata;
}
74
/// Metadata for loading a model with ISQ or device mapping.
pub struct NormalLoadingMetadata {
    /// Device mapping metadata which can be used to construct a concrete device mapper.
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Whether the model is being loaded for in-situ quantization (ISQ).
    pub loading_isq: bool,
    /// Device mapping target device (the one that is not the CPU).
    pub real_device: Device,
    /// `MultiProgress` handle used to report progress during parallelized loading.
    pub multi_progress: Arc<MultiProgress>,
    /// Optional Matryoshka Transformer (MatFormer) slicing configuration.
    pub matformer_slicing_config: Option<MatformerSliceConfig>,
}
88
89pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
90    fn load(
91        &self,
92        config: &str,
93        vb: ShardedVarBuilder,
94        normal_loading_metadata: NormalLoadingMetadata,
95        attention_mechanism: AttentionImplementation,
96    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
97    #[allow(clippy::too_many_arguments)]
98    fn load_xlora(
99        &self,
100        config: &str,
101        vb: ShardedVarBuilder,
102        lora_config: &[((String, String), LoraConfig)],
103        xlora_config: Option<XLoraConfig>,
104        xlora_ordering: Ordering,
105        normal_loading_metadata: NormalLoadingMetadata,
106        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
107    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
108    fn is_gptx(&self, config: &str) -> Result<bool>;
109    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
110        Ok(true)
111    }
112    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
113    fn get_device_for_tensor(
114        &self,
115        config: &str,
116        _mapper: &dyn DeviceMapper,
117        loading_isq: bool,
118    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
119        if loading_isq {
120            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
121        } else {
122            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
123            let num_layers = self.model_config(config)?.num_layers();
124            let closure = move |name: String| {
125                if let Some(captures) = re.captures(&name) {
126                    captures
127                        .get(1)
128                        .and_then(|m| m.as_str().parse::<usize>().ok())
129                        .map(|l| l.min(num_layers))
130                        .map(DeviceForLoadTensor::Idx)
131                        .unwrap_or(DeviceForLoadTensor::Base)
132                } else {
133                    DeviceForLoadTensor::Base
134                }
135            };
136
137            Ok(Arc::new(closure))
138        }
139    }
140}
141
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the normal model as.
///
/// Each variant deserializes from the lowercase identifier in its `serde(rename)`
/// attribute. `FromStr` and `Display` use the same spelling.
pub enum NormalLoaderType {
    /// Hugging Face class `MistralForCausalLM`.
    #[serde(rename = "mistral")]
    Mistral,
    /// Hugging Face class `GemmaForCausalLM`.
    #[serde(rename = "gemma")]
    Gemma,
    /// Hugging Face class `MixtralForCausalLM`.
    #[serde(rename = "mixtral")]
    Mixtral,
    /// Hugging Face class `LlamaForCausalLM`.
    #[serde(rename = "llama")]
    Llama,
    /// Hugging Face class `PhiForCausalLM`.
    #[serde(rename = "phi2")]
    Phi2,
    /// Hugging Face class `Phi3ForCausalLM`.
    #[serde(rename = "phi3")]
    Phi3,
    /// Hugging Face class `Qwen2ForCausalLM`.
    #[serde(rename = "qwen2")]
    Qwen2,
    /// Hugging Face class `Gemma2ForCausalLM`.
    #[serde(rename = "gemma2")]
    Gemma2,
    /// Hugging Face class `Starcoder2ForCausalLM`.
    #[serde(rename = "starcoder2")]
    Starcoder2,
    /// Hugging Face class `PhiMoEForCausalLM`.
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    /// Hugging Face class `DeepseekV2ForCausalLM`.
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    /// Hugging Face class `DeepseekV3ForCausalLM`.
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    /// Hugging Face class `Qwen3ForCausalLM`.
    #[serde(rename = "qwen3")]
    Qwen3,
    /// Hugging Face class `Glm4ForCausalLM`.
    #[serde(rename = "glm4")]
    GLM4,
    /// Hugging Face class `Qwen3MoeForCausalLM`.
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
    /// Hugging Face class `SmolLM3ForCausalLM`.
    #[serde(rename = "smollm3")]
    SmolLm3,
}
179
180// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
181impl NormalLoaderType {
182    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
183        match name {
184            "MistralForCausalLM" => Ok(Self::Mistral),
185            "MixtralForCausalLM" => Ok(Self::Mixtral),
186            "GemmaForCausalLM" => Ok(Self::Gemma),
187            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
188            "PhiForCausalLM" => Ok(Self::Phi2),
189            "Phi3ForCausalLM" => Ok(Self::Phi3),
190            "LlamaForCausalLM" => Ok(Self::Llama),
191            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
192            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
193            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
194            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
195            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
196            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
197            "Glm4ForCausalLM" => Ok(Self::GLM4),
198            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
199            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
200            other => anyhow::bail!(
201                "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
202            ),
203        }
204    }
205}
206
207impl FromStr for NormalLoaderType {
208    type Err = String;
209    fn from_str(s: &str) -> Result<Self, Self::Err> {
210        match s {
211            "mistral" => Ok(Self::Mistral),
212            "gemma" => Ok(Self::Gemma),
213            "mixtral" => Ok(Self::Mixtral),
214            "llama" => Ok(Self::Llama),
215            "phi2" => Ok(Self::Phi2),
216            "phi3" => Ok(Self::Phi3),
217            "qwen2" => Ok(Self::Qwen2),
218            "gemma2" => Ok(Self::Gemma2),
219            "starcoder2" => Ok(Self::Starcoder2),
220            "phi3.5moe" => Ok(Self::Phi3_5MoE),
221            "deepseekv2" => Ok(Self::DeepSeekV2),
222            "deepseekv3" => Ok(Self::DeepSeekV3),
223            "qwen3" => Ok(Self::Qwen3),
224            "glm4" => Ok(Self::GLM4),
225            "qwen3moe" => Ok(Self::Qwen3Moe),
226            "smollm3" => Ok(Self::SmolLm3),
227            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`.")),
228        }
229    }
230}
231
232impl Display for NormalLoaderType {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        match self {
235            Self::Gemma => write!(f, "gemma"),
236            Self::Gemma2 => write!(f, "gemma2"),
237            Self::Llama => write!(f, "llama"),
238            Self::Mistral => write!(f, "mistral"),
239            Self::Mixtral => write!(f, "mixtral"),
240            Self::Phi2 => write!(f, "phi2"),
241            Self::Phi3 => write!(f, "phi3"),
242            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
243            Self::Qwen2 => write!(f, "qwen2"),
244            Self::Starcoder2 => write!(f, "starcoder2"),
245            Self::DeepSeekV2 => write!(f, "deepseekv2"),
246            Self::DeepSeekV3 => write!(f, "deepseekv3"),
247            Self::Qwen3 => write!(f, "qwen3"),
248            Self::GLM4 => write!(f, "glm4"),
249            Self::Qwen3Moe => write!(f, "qwen3moe"),
250            Self::SmolLm3 => write!(f, "smollm3"),
251        }
252    }
253}
254
/// Evaluates to `$size` when `$cond` is true and to `0` otherwise. Used to count a bias
/// vector's elements only when the layer actually has a bias.
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        match $cond {
            true => $size,
            false => 0,
        }
    };
}
264
/// Load a model based on the Hugging Face Transformers `*ForCausalLM` model class.
///
/// The model config JSON must contain an `architectures` list with exactly one class name.
/// That name selects the architecture-specific loader, and the trait methods of this type
/// delegate to it.
pub struct AutoNormalLoader;
267
/// Minimal view of a model config, used only to pick the architecture.
///
/// Only `architectures` is read. Serde ignores other fields because `deny_unknown_fields`
/// is not set.
#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    /// Hugging Face class names. `AutoNormalLoader` requires exactly one entry.
    architectures: Vec<String>,
}
272
273impl AutoNormalLoader {
274    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
275        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
276        if auto_cfg.architectures.len() != 1 {
277            anyhow::bail!("Expected to have one name for `architectures` config field.")
278        }
279
280        let name = &auto_cfg.architectures[0];
281
282        let tp = NormalLoaderType::from_causal_lm_name(name)?;
283
284        once_log_info(format!("Automatic loader type determined to be `{tp}`"));
285
286        match tp {
287            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
288            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
289            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
290            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
291            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
292            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
293            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
294            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
295            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
296            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
297            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
298            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
299            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
300            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
301            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
302            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
303        }
304    }
305}
306
307impl NormalModelLoader for AutoNormalLoader {
308    fn load(
309        &self,
310        config: &str,
311        vb: ShardedVarBuilder,
312        normal_loading_metadata: NormalLoadingMetadata,
313        attention_mechanism: AttentionImplementation,
314    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
315        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
316    }
317    fn load_xlora(
318        &self,
319        config: &str,
320        vb: ShardedVarBuilder,
321        lora_config: &[((String, String), LoraConfig)],
322        xlora_config: Option<XLoraConfig>,
323        xlora_ordering: Ordering,
324        normal_loading_metadata: NormalLoadingMetadata,
325        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
326    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
327        Self::get_loader(config)?.load_xlora(
328            config,
329            vb,
330            lora_config,
331            xlora_config,
332            xlora_ordering,
333            normal_loading_metadata,
334            preload_adapters,
335        )
336    }
337    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
338        Self::get_loader(config)?.get_config_repr(config)
339    }
340    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
341        Self::get_loader(config)?.supports_paged_attention(config)
342    }
343    fn is_gptx(&self, config: &str) -> Result<bool> {
344        Self::get_loader(config)?.is_gptx(config)
345    }
346}
347
348impl IsqModelLoader for AutoNormalLoader {
349    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
350        Self::get_loader(config)?.immediate_isq_predicates(config)
351    }
352    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
353        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
354    }
355    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
356        Self::get_loader(config)?.isq_layer_regexes(config)
357    }
358    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
359        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
360    }
361}
362
363impl DeviceMappedModelLoader for AutoNormalLoader {
364    fn non_mapped_size_in_bytes(
365        &self,
366        config: &str,
367        dtype: DType,
368        weight_pack_factor: usize,
369        _matformer_config: Option<&MatformerSliceConfig>,
370    ) -> Result<usize> {
371        Self::get_loader(config)?.non_mapped_size_in_bytes(
372            config,
373            dtype,
374            weight_pack_factor,
375            _matformer_config,
376        )
377    }
378    fn num_layers(&self, config: &str) -> Result<usize> {
379        Self::get_loader(config)?.num_layers(config)
380    }
381    fn layer_sizes_in_bytes(
382        &self,
383        config: &str,
384        dtype: DType,
385        weight_pack_factor: usize,
386        _matformer_config: Option<&MatformerSliceConfig>,
387    ) -> Result<Vec<usize>> {
388        Self::get_loader(config)?.layer_sizes_in_bytes(
389            config,
390            dtype,
391            weight_pack_factor,
392            _matformer_config,
393        )
394    }
395    fn mapped_max_act_size_elems(
396        &self,
397        config: &str,
398        params: &super::AutoDeviceMapParams,
399        prompt_chunksize: usize,
400    ) -> Result<usize> {
401        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
402    }
403    fn non_mapped_max_act_size_elems(
404        &self,
405        _config: &str,
406        _params: &AutoDeviceMapParams,
407    ) -> Result<usize> {
408        Ok(0)
409    }
410    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
411        Self::get_loader(config)?.model_config(config)
412    }
413}
414
415// ======================== Mistral loader
416
/// [`NormalLoader`] for a Mistral model.
///
/// Parses the config as `crate::models::mistral::Config` and builds `models::mistral::Model`,
/// or `xlora_models::XLoraMistral` when loading with adapters.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct MistralLoader;
418
419impl NormalModelLoader for MistralLoader {
420    fn load(
421        &self,
422        config: &str,
423        vb: ShardedVarBuilder,
424        normal_loading_metadata: NormalLoadingMetadata,
425        attention_mechanism: AttentionImplementation,
426    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
427        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
428        Ok(Box::new(models::mistral::Model::new(
429            &cfg,
430            vb,
431            self.is_gptx(config)?,
432            normal_loading_metadata,
433            attention_mechanism,
434        )?))
435    }
436    fn load_xlora(
437        &self,
438        config: &str,
439        vb: ShardedVarBuilder,
440        lora_config: &[((String, String), LoraConfig)],
441        xlora_config: Option<XLoraConfig>,
442        xlora_ordering: Ordering,
443        normal_loading_metadata: NormalLoadingMetadata,
444        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
445    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
446        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
447        Ok(Box::new(xlora_models::XLoraMistral::new(
448            &cfg,
449            vb,
450            lora_config,
451            xlora_config,
452            xlora_ordering,
453            self.is_gptx(config)?,
454            normal_loading_metadata,
455            preload_adapters,
456        )?))
457    }
458    fn is_gptx(&self, _: &str) -> Result<bool> {
459        Ok(true)
460    }
461    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
462        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
463        Ok(Box::new(cfg))
464    }
465}
466
467impl IsqModelLoader for MistralLoader {
468    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
469        Ok(vec![
470            Regex::new(r"lm_head\.(weight|bias)$")?,
471            // Attention
472            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
473            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
474            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
475            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
476            // MLP
477            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
478            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
479            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
480        ])
481    }
482    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
483        self.isq_layer_regexes(config)
484    }
485}
486
487impl DeviceMappedModelLoader for MistralLoader {
488    fn mapped_max_act_size_elems(
489        &self,
490        config: &str,
491        params: &AutoDeviceMapParams,
492        prompt_chunksize: usize,
493    ) -> Result<usize> {
494        let AutoDeviceMapParams::Text {
495            max_seq_len: _,
496            max_batch_size,
497        } = params
498        else {
499            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
500        };
501
502        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
503
504        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
505    }
506    fn non_mapped_max_act_size_elems(
507        &self,
508        _config: &str,
509        _params: &AutoDeviceMapParams,
510    ) -> Result<usize> {
511        Ok(0)
512    }
513
514    fn non_mapped_size_in_bytes(
515        &self,
516        config: &str,
517        dtype: DType,
518        weight_pack_factor: usize,
519        _matformer_config: Option<&MatformerSliceConfig>,
520    ) -> Result<usize> {
521        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
522
523        let elems = {
524            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
525            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
526            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
527                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
528            } else {
529                0
530            };
531            let norm = cfg.hidden_size;
532            embed_tokens + lm_head + norm
533        };
534        Ok(elems * dtype.size_in_bytes())
535    }
536
537    fn layer_sizes_in_bytes(
538        &self,
539        config: &str,
540        dtype: DType,
541        weight_pack_factor: usize,
542        _matformer_config: Option<&MatformerSliceConfig>,
543    ) -> Result<Vec<usize>> {
544        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
545
546        let per_layer_elems = {
547            let input_layernorm = cfg.hidden_size;
548            let post_attention_layernorm = cfg.hidden_size;
549
550            let size_in = cfg.hidden_size;
551            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
552            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
553            let q_proj = size_in * size_q / weight_pack_factor;
554            let k_proj = size_in * size_kv / weight_pack_factor;
555            let v_proj = size_in * size_kv / weight_pack_factor;
556            let o_proj = size_q * size_in / weight_pack_factor;
557
558            let h_size = cfg.hidden_size;
559            let i_size = cfg.intermediate_size;
560            let gate_proj = h_size * i_size / weight_pack_factor;
561            let up_proj = h_size * i_size / weight_pack_factor;
562            let down_proj = i_size * h_size / weight_pack_factor;
563
564            input_layernorm
565                + post_attention_layernorm
566                + q_proj
567                + k_proj
568                + v_proj
569                + o_proj
570                + gate_proj
571                + up_proj
572                + down_proj
573        };
574        Ok(vec![
575            per_layer_elems * dtype.size_in_bytes();
576            cfg.num_hidden_layers
577        ])
578    }
579
580    fn num_layers(&self, config: &str) -> Result<usize> {
581        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
582        Ok(cfg.num_hidden_layers)
583    }
584
585    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
586        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
587
588        let cfg = ModelConfigMetadata {
589            max_seq_len: cfg.max_position_embeddings,
590            num_layers: cfg.num_hidden_layers,
591            hidden_size: cfg.hidden_size,
592            num_kv_heads: cfg.num_key_value_heads,
593            num_attn_heads: cfg.num_attention_heads,
594            sliding_window: cfg.sliding_window,
595            k_head_dim: cfg.head_dim(),
596            v_head_dim: cfg.head_dim(),
597        };
598
599        Ok(Box::new(cfg))
600    }
601}
602
603// ======================== Gemma loader
604
/// [`NormalLoader`] for a Gemma model.
///
/// Parses the config as `crate::models::gemma::Config` and builds `models::gemma::Model`,
/// or `xlora_models::XLoraGemma` when loading with adapters.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct GemmaLoader;
609
610impl NormalModelLoader for GemmaLoader {
611    fn load(
612        &self,
613        config: &str,
614        vb: ShardedVarBuilder,
615        normal_loading_metadata: NormalLoadingMetadata,
616        attention_mechanism: AttentionImplementation,
617    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
618        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
619
620        Ok(Box::new(models::gemma::Model::new(
621            &cfg,
622            vb,
623            self.is_gptx(config)?,
624            normal_loading_metadata,
625            attention_mechanism,
626        )?))
627    }
628    fn load_xlora(
629        &self,
630        config: &str,
631        vb: ShardedVarBuilder,
632        lora_config: &[((String, String), LoraConfig)],
633        xlora_config: Option<XLoraConfig>,
634        xlora_ordering: Ordering,
635        normal_loading_metadata: NormalLoadingMetadata,
636        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
637    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
638        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
639
640        Ok(Box::new(xlora_models::XLoraGemma::new(
641            &cfg,
642            vb,
643            lora_config,
644            xlora_config,
645            xlora_ordering,
646            self.is_gptx(config)?,
647            normal_loading_metadata,
648            preload_adapters,
649        )?))
650    }
651    fn is_gptx(&self, _: &str) -> Result<bool> {
652        Ok(true)
653    }
654    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
655        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
656        Ok(Box::new(cfg))
657    }
658}
659
660impl IsqModelLoader for GemmaLoader {
661    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
662        Ok(vec![
663            Regex::new(r"lm_head\.(weight|bias)$")?,
664            // Attention
665            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
666            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
667            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
668            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
669            // MLP
670            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
671            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
672            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
673        ])
674    }
675    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
676        self.isq_layer_regexes(config)
677    }
678}
679
680impl DeviceMappedModelLoader for GemmaLoader {
681    fn mapped_max_act_size_elems(
682        &self,
683        config: &str,
684        params: &AutoDeviceMapParams,
685        prompt_chunksize: usize,
686    ) -> Result<usize> {
687        let AutoDeviceMapParams::Text {
688            max_seq_len: _,
689            max_batch_size,
690        } = params
691        else {
692            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
693        };
694
695        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
696
697        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
698    }
699    fn non_mapped_max_act_size_elems(
700        &self,
701        _config: &str,
702        _params: &AutoDeviceMapParams,
703    ) -> Result<usize> {
704        Ok(0)
705    }
706
707    fn non_mapped_size_in_bytes(
708        &self,
709        config: &str,
710        dtype: DType,
711        weight_pack_factor: usize,
712        _matformer_config: Option<&MatformerSliceConfig>,
713    ) -> Result<usize> {
714        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
715
716        let elems = {
717            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
718            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
719            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
720                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
721            } else {
722                0
723            };
724            let norm = cfg.hidden_size;
725            embed_tokens + lm_head + norm
726        };
727        Ok(elems * dtype.size_in_bytes())
728    }
729
730    fn layer_sizes_in_bytes(
731        &self,
732        config: &str,
733        dtype: DType,
734        weight_pack_factor: usize,
735        _matformer_config: Option<&MatformerSliceConfig>,
736    ) -> Result<Vec<usize>> {
737        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
738
739        let per_layer_elems = {
740            let input_layernorm = cfg.hidden_size;
741            let post_attention_layernorm = cfg.hidden_size;
742
743            let size_in = cfg.hidden_size;
744            let size_q = cfg.head_dim * cfg.num_attention_heads;
745            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
746            let q_proj =
747                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
748            let k_proj =
749                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
750            let v_proj =
751                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
752            let o_proj =
753                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
754
755            let h_size = cfg.hidden_size;
756            let i_size = cfg.intermediate_size;
757            let gate_proj = h_size * i_size / weight_pack_factor;
758            let up_proj = h_size * i_size / weight_pack_factor;
759            let down_proj = i_size * h_size / weight_pack_factor;
760
761            input_layernorm
762                + post_attention_layernorm
763                + q_proj
764                + k_proj
765                + v_proj
766                + o_proj
767                + gate_proj
768                + up_proj
769                + down_proj
770        };
771        Ok(vec![
772            per_layer_elems * dtype.size_in_bytes();
773            cfg.num_hidden_layers
774        ])
775    }
776
777    fn num_layers(&self, config: &str) -> Result<usize> {
778        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
779        Ok(cfg.num_hidden_layers)
780    }
781
782    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
783        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
784
785        let cfg = ModelConfigMetadata {
786            max_seq_len: cfg.max_position_embeddings,
787            num_layers: cfg.num_hidden_layers,
788            hidden_size: cfg.hidden_size,
789            num_kv_heads: cfg.num_key_value_heads,
790            num_attn_heads: cfg.num_attention_heads,
791            sliding_window: None,
792            k_head_dim: cfg.head_dim,
793            v_head_dim: cfg.head_dim,
794        };
795
796        Ok(Box::new(cfg))
797    }
798}
799
800// ======================== Llama loader
801
/// [`NormalLoader`] for a Llama model.
///
/// Parses the config as `crate::models::llama::Config` and builds `models::llama::Llama`,
/// or `xlora_models::XLoraLlama` when loading with adapters.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct LlamaLoader;
806
807impl NormalModelLoader for LlamaLoader {
808    fn load(
809        &self,
810        config: &str,
811        vb: ShardedVarBuilder,
812        normal_loading_metadata: NormalLoadingMetadata,
813        attention_mechanism: AttentionImplementation,
814    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
815        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
816
817        Ok(Box::new(models::llama::Llama::new(
818            &cfg,
819            vb,
820            self.is_gptx(config)?,
821            normal_loading_metadata,
822            attention_mechanism,
823        )?))
824    }
825    fn load_xlora(
826        &self,
827        config: &str,
828        vb: ShardedVarBuilder,
829        lora_config: &[((String, String), LoraConfig)],
830        xlora_config: Option<XLoraConfig>,
831        xlora_ordering: Ordering,
832        normal_loading_metadata: NormalLoadingMetadata,
833        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
834    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
835        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
836
837        Ok(Box::new(xlora_models::XLoraLlama::new(
838            &cfg,
839            vb,
840            lora_config,
841            xlora_config,
842            xlora_ordering,
843            self.is_gptx(config)?,
844            normal_loading_metadata,
845            preload_adapters,
846        )?))
847    }
848    fn is_gptx(&self, _: &str) -> Result<bool> {
849        Ok(true)
850    }
851    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
852        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
853        Ok(Box::new(cfg))
854    }
855}
856
857impl IsqModelLoader for LlamaLoader {
858    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
859        Ok(vec![
860            Regex::new(r"lm_head\.(weight|bias)$")?,
861            // Attention
862            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
863            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
864            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
865            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
866            // MLP
867            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
868            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
869            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
870        ])
871    }
872    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
873        self.isq_layer_regexes(config)
874    }
875}
876
877impl DeviceMappedModelLoader for LlamaLoader {
878    fn mapped_max_act_size_elems(
879        &self,
880        config: &str,
881        params: &AutoDeviceMapParams,
882        prompt_chunksize: usize,
883    ) -> Result<usize> {
884        let AutoDeviceMapParams::Text {
885            max_seq_len: _,
886            max_batch_size,
887        } = params
888        else {
889            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
890        };
891
892        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
893
894        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
895    }
896    fn non_mapped_max_act_size_elems(
897        &self,
898        _config: &str,
899        _params: &AutoDeviceMapParams,
900    ) -> Result<usize> {
901        Ok(0)
902    }
903
904    fn non_mapped_size_in_bytes(
905        &self,
906        config: &str,
907        dtype: DType,
908        weight_pack_factor: usize,
909        _matformer_config: Option<&MatformerSliceConfig>,
910    ) -> Result<usize> {
911        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
912
913        let elems = {
914            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
915            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
916            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
917                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
918            } else {
919                0
920            };
921            let norm = cfg.hidden_size;
922            embed_tokens + lm_head + norm
923        };
924        Ok(elems * dtype.size_in_bytes())
925    }
926
927    fn layer_sizes_in_bytes(
928        &self,
929        config: &str,
930        dtype: DType,
931        weight_pack_factor: usize,
932        _matformer_config: Option<&MatformerSliceConfig>,
933    ) -> Result<Vec<usize>> {
934        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
935
936        let per_layer_elems = {
937            let input_layernorm = cfg.hidden_size;
938            let post_attention_layernorm = cfg.hidden_size;
939
940            let size_in = cfg.hidden_size;
941            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
942            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
943            let q_proj = size_in * size_q / weight_pack_factor;
944            let k_proj = size_in * size_kv / weight_pack_factor;
945            let v_proj = size_in * size_kv / weight_pack_factor;
946            let o_proj = size_q * size_in / weight_pack_factor;
947
948            let h_size = cfg.hidden_size;
949            let i_size = cfg.intermediate_size;
950            let gate_proj = h_size * i_size / weight_pack_factor;
951            let up_proj = h_size * i_size / weight_pack_factor;
952            let down_proj = i_size * h_size / weight_pack_factor;
953
954            input_layernorm
955                + post_attention_layernorm
956                + q_proj
957                + k_proj
958                + v_proj
959                + o_proj
960                + gate_proj
961                + up_proj
962                + down_proj
963        };
964        Ok(vec![
965            per_layer_elems * dtype.size_in_bytes();
966            cfg.num_hidden_layers
967        ])
968    }
969
970    fn num_layers(&self, config: &str) -> Result<usize> {
971        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
972
973        Ok(cfg.num_hidden_layers)
974    }
975    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
976        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
977
978        let cfg = ModelConfigMetadata {
979            max_seq_len: cfg.max_position_embeddings,
980            num_layers: cfg.num_hidden_layers,
981            hidden_size: cfg.hidden_size,
982            num_kv_heads: cfg.num_key_value_heads,
983            num_attn_heads: cfg.num_attention_heads,
984            sliding_window: None,
985            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
986            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
987        };
988
989        Ok(Box::new(cfg))
990    }
991}
992
993// ======================== Mixtral loader
994
995pub struct MixtralLoader;
996
997impl NormalModelLoader for MixtralLoader {
998    fn load(
999        &self,
1000        config: &str,
1001        vb: ShardedVarBuilder,
1002        normal_loading_metadata: NormalLoadingMetadata,
1003        attention_mechanism: AttentionImplementation,
1004    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1005        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1006
1007        Ok(Box::new(models::mixtral::Model::new(
1008            &cfg,
1009            vb,
1010            self.is_gptx(config)?,
1011            normal_loading_metadata,
1012            attention_mechanism,
1013        )?))
1014    }
1015    fn load_xlora(
1016        &self,
1017        config: &str,
1018        vb: ShardedVarBuilder,
1019        lora_config: &[((String, String), LoraConfig)],
1020        xlora_config: Option<XLoraConfig>,
1021        xlora_ordering: Ordering,
1022        normal_loading_metadata: NormalLoadingMetadata,
1023        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1024    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1025        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1026
1027        Ok(Box::new(xlora_models::XLoraMixtral::new(
1028            &cfg,
1029            vb,
1030            lora_config,
1031            xlora_config,
1032            xlora_ordering,
1033            self.is_gptx(config)?,
1034            normal_loading_metadata,
1035            preload_adapters,
1036        )?))
1037    }
1038    fn is_gptx(&self, _: &str) -> Result<bool> {
1039        Ok(true)
1040    }
1041    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1042        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1043
1044        Ok(Box::new(cfg))
1045    }
1046}
1047
1048impl IsqModelLoader for MixtralLoader {
1049    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1050        Ok(vec![
1051            Regex::new(r"lm_head\.(weight|bias)$")?,
1052            // Attention
1053            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1054            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1055            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1056            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1057            // Experts
1058            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1059            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1060            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1061            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1062        ])
1063    }
1064    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1065        self.isq_layer_regexes(config)
1066    }
1067}
1068
1069impl DeviceMappedModelLoader for MixtralLoader {
1070    fn mapped_max_act_size_elems(
1071        &self,
1072        config: &str,
1073        params: &AutoDeviceMapParams,
1074        prompt_chunksize: usize,
1075    ) -> Result<usize> {
1076        let AutoDeviceMapParams::Text {
1077            max_seq_len: _,
1078            max_batch_size,
1079        } = params
1080        else {
1081            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1082        };
1083
1084        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1085
1086        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1087    }
1088    fn non_mapped_max_act_size_elems(
1089        &self,
1090        _config: &str,
1091        _params: &AutoDeviceMapParams,
1092    ) -> Result<usize> {
1093        Ok(0)
1094    }
1095
1096    fn non_mapped_size_in_bytes(
1097        &self,
1098        config: &str,
1099        dtype: DType,
1100        weight_pack_factor: usize,
1101        _matformer_config: Option<&MatformerSliceConfig>,
1102    ) -> Result<usize> {
1103        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1104
1105        let elems = {
1106            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1107            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1108            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1109                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1110            } else {
1111                0
1112            };
1113            let norm = cfg.hidden_size;
1114            embed_tokens + lm_head + norm
1115        };
1116        Ok(elems * dtype.size_in_bytes())
1117    }
1118
1119    fn layer_sizes_in_bytes(
1120        &self,
1121        config: &str,
1122        dtype: DType,
1123        weight_pack_factor: usize,
1124        _matformer_config: Option<&MatformerSliceConfig>,
1125    ) -> Result<Vec<usize>> {
1126        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1127
1128        let per_layer_elems = {
1129            let input_layernorm = cfg.hidden_size;
1130            let post_attention_layernorm = cfg.hidden_size;
1131
1132            let size_in = cfg.hidden_size;
1133            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1134            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1135            let q_proj = size_in * size_q / weight_pack_factor;
1136            let k_proj = size_in * size_kv / weight_pack_factor;
1137            let v_proj = size_in * size_kv / weight_pack_factor;
1138            let o_proj = size_q * size_in / weight_pack_factor;
1139
1140            let moe_block = {
1141                let gate = cfg.hidden_size * cfg.num_local_experts;
1142                // Assume quantizing weight pack factor
1143                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1144                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1145                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1146                gate + cfg.num_local_experts * w1
1147                    + cfg.num_local_experts * w2
1148                    + cfg.num_local_experts * w3
1149            };
1150
1151            input_layernorm
1152                + post_attention_layernorm
1153                + q_proj
1154                + k_proj
1155                + v_proj
1156                + o_proj
1157                + moe_block
1158        };
1159        Ok(vec![
1160            per_layer_elems * dtype.size_in_bytes();
1161            cfg.num_hidden_layers
1162        ])
1163    }
1164
1165    fn num_layers(&self, config: &str) -> Result<usize> {
1166        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1167
1168        Ok(cfg.num_hidden_layers)
1169    }
1170
1171    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1172        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1173
1174        let cfg = ModelConfigMetadata {
1175            max_seq_len: cfg.max_position_embeddings,
1176            num_layers: cfg.num_hidden_layers,
1177            hidden_size: cfg.hidden_size,
1178            num_kv_heads: cfg.num_key_value_heads,
1179            num_attn_heads: cfg.num_attention_heads,
1180            sliding_window: cfg.sliding_window,
1181            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1182            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1183        };
1184
1185        Ok(Box::new(cfg))
1186    }
1187}
1188
1189// ======================== Phi2 loader
1190
1191/// [`NormalLoader`] for a Phi 2 model.
1192///
1193/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1194pub struct Phi2Loader;
1195
1196impl NormalModelLoader for Phi2Loader {
1197    fn load(
1198        &self,
1199        config: &str,
1200        vb: ShardedVarBuilder,
1201        normal_loading_metadata: NormalLoadingMetadata,
1202        attention_mechanism: AttentionImplementation,
1203    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1204        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1205
1206        Ok(Box::new(models::phi2::Model::new(
1207            &cfg,
1208            vb,
1209            self.is_gptx(config)?,
1210            normal_loading_metadata,
1211            attention_mechanism,
1212        )?))
1213    }
1214    fn load_xlora(
1215        &self,
1216        config: &str,
1217        vb: ShardedVarBuilder,
1218        lora_config: &[((String, String), LoraConfig)],
1219        xlora_config: Option<XLoraConfig>,
1220        xlora_ordering: Ordering,
1221        normal_loading_metadata: NormalLoadingMetadata,
1222        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1223    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1224        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1225
1226        Ok(Box::new(xlora_models::XLoraPhi2::new(
1227            &cfg,
1228            vb,
1229            lora_config,
1230            xlora_config,
1231            xlora_ordering,
1232            self.is_gptx(config)?,
1233            normal_loading_metadata,
1234            preload_adapters,
1235        )?))
1236    }
1237    fn is_gptx(&self, _: &str) -> Result<bool> {
1238        Ok(true)
1239    }
1240    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1241        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1242
1243        Ok(Box::new(cfg))
1244    }
1245}
1246
1247impl IsqModelLoader for Phi2Loader {
1248    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1249        Ok(vec![
1250            Regex::new(r"lm_head\.(weight|bias)$")?,
1251            // Attention
1252            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1253            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1254            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1255            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1256            // MLP
1257            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1258            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1259        ])
1260    }
1261    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1262        self.isq_layer_regexes(config)
1263    }
1264}
1265
1266impl DeviceMappedModelLoader for Phi2Loader {
1267    fn mapped_max_act_size_elems(
1268        &self,
1269        config: &str,
1270        params: &AutoDeviceMapParams,
1271        prompt_chunksize: usize,
1272    ) -> Result<usize> {
1273        let AutoDeviceMapParams::Text {
1274            max_seq_len: _,
1275            max_batch_size,
1276        } = params
1277        else {
1278            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1279        };
1280
1281        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1282
1283        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1284    }
1285    fn non_mapped_max_act_size_elems(
1286        &self,
1287        _config: &str,
1288        _params: &AutoDeviceMapParams,
1289    ) -> Result<usize> {
1290        Ok(0)
1291    }
1292
1293    fn non_mapped_size_in_bytes(
1294        &self,
1295        config: &str,
1296        dtype: DType,
1297        weight_pack_factor: usize,
1298        _matformer_config: Option<&MatformerSliceConfig>,
1299    ) -> Result<usize> {
1300        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1301
1302        let elems = {
1303            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1304            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1305            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1306                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1307            } else {
1308                0
1309            };
1310            let norm = cfg.hidden_size;
1311            embed_tokens + lm_head + norm
1312        };
1313        Ok(elems * dtype.size_in_bytes())
1314    }
1315
1316    fn layer_sizes_in_bytes(
1317        &self,
1318        config: &str,
1319        dtype: DType,
1320        weight_pack_factor: usize,
1321        _matformer_config: Option<&MatformerSliceConfig>,
1322    ) -> Result<Vec<usize>> {
1323        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1324
1325        let per_layer_elems = {
1326            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1327
1328            let size_in = cfg.hidden_size;
1329            let size_q = cfg.head_dim() * cfg.num_attention_heads;
1330            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1331            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1332            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1333            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1334            let o_proj = size_q * size_in / weight_pack_factor + size_in;
1335            let (q_norm, k_norm) = if cfg.qk_layernorm {
1336                (cfg.head_dim(), cfg.head_dim())
1337            } else {
1338                (0, 0)
1339            };
1340
1341            let h_size = cfg.hidden_size;
1342            let i_size = cfg.intermediate_size;
1343            let fc1 = h_size * i_size / weight_pack_factor;
1344            let fc2 = h_size * i_size / weight_pack_factor;
1345
1346            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1347        };
1348        Ok(vec![
1349            per_layer_elems * dtype.size_in_bytes();
1350            cfg.num_hidden_layers
1351        ])
1352    }
1353
1354    fn num_layers(&self, config: &str) -> Result<usize> {
1355        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1356
1357        Ok(cfg.num_hidden_layers)
1358    }
1359
1360    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1361        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1362
1363        let cfg = ModelConfigMetadata {
1364            max_seq_len: cfg.max_position_embeddings,
1365            num_layers: cfg.num_hidden_layers,
1366            hidden_size: cfg.hidden_size,
1367            num_kv_heads: cfg.num_key_value_heads(),
1368            num_attn_heads: cfg.num_attention_heads,
1369            sliding_window: None,
1370            k_head_dim: cfg.head_dim(),
1371            v_head_dim: cfg.head_dim(),
1372        };
1373
1374        Ok(Box::new(cfg))
1375    }
1376}
1377
1378// ======================== Phi3 loader
1379
1380/// [`NormalLoader`] for a Phi 3 model.
1381///
1382/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1383pub struct Phi3Loader;
1384
1385impl NormalModelLoader for Phi3Loader {
1386    fn load(
1387        &self,
1388        config: &str,
1389        vb: ShardedVarBuilder,
1390        normal_loading_metadata: NormalLoadingMetadata,
1391        attention_mechanism: AttentionImplementation,
1392    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1393        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1394
1395        Ok(Box::new(models::phi3::Model::new(
1396            &cfg,
1397            vb,
1398            self.is_gptx(config)?,
1399            normal_loading_metadata,
1400            attention_mechanism,
1401        )?))
1402    }
1403    fn load_xlora(
1404        &self,
1405        config: &str,
1406        vb: ShardedVarBuilder,
1407        lora_config: &[((String, String), LoraConfig)],
1408        xlora_config: Option<XLoraConfig>,
1409        xlora_ordering: Ordering,
1410        normal_loading_metadata: NormalLoadingMetadata,
1411        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1412    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1413        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1414
1415        Ok(Box::new(xlora_models::XLoraPhi3::new(
1416            &cfg,
1417            vb,
1418            lora_config,
1419            xlora_config,
1420            xlora_ordering,
1421            self.is_gptx(config)?,
1422            normal_loading_metadata,
1423            preload_adapters,
1424        )?))
1425    }
1426    fn is_gptx(&self, _: &str) -> Result<bool> {
1427        Ok(true)
1428    }
1429    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1430        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1431
1432        Ok(Box::new(cfg))
1433    }
1434}
1435
1436impl IsqModelLoader for Phi3Loader {
1437    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1438        Ok(vec![
1439            Regex::new(r"lm_head\.(weight|bias)$")?,
1440            // Attention
1441            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1442            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1443            // MLP
1444            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1445            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1446            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1447        ])
1448    }
1449    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1450        self.isq_layer_regexes(config)
1451    }
1452}
1453
1454impl DeviceMappedModelLoader for Phi3Loader {
1455    fn mapped_max_act_size_elems(
1456        &self,
1457        config: &str,
1458        params: &AutoDeviceMapParams,
1459        prompt_chunksize: usize,
1460    ) -> Result<usize> {
1461        let AutoDeviceMapParams::Text {
1462            max_seq_len: _,
1463            max_batch_size,
1464        } = params
1465        else {
1466            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1467        };
1468
1469        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1470
1471        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1472    }
1473    fn non_mapped_max_act_size_elems(
1474        &self,
1475        _config: &str,
1476        _params: &AutoDeviceMapParams,
1477    ) -> Result<usize> {
1478        Ok(0)
1479    }
1480
1481    fn non_mapped_size_in_bytes(
1482        &self,
1483        config: &str,
1484        dtype: DType,
1485        weight_pack_factor: usize,
1486        _matformer_config: Option<&MatformerSliceConfig>,
1487    ) -> Result<usize> {
1488        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1489
1490        let elems = {
1491            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1492            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1493            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1494                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1495            } else {
1496                0
1497            };
1498            let norm = cfg.hidden_size;
1499            embed_tokens + lm_head + norm
1500        };
1501        Ok(elems * dtype.size_in_bytes())
1502    }
1503
1504    fn layer_sizes_in_bytes(
1505        &self,
1506        config: &str,
1507        dtype: DType,
1508        weight_pack_factor: usize,
1509        _matformer_config: Option<&MatformerSliceConfig>,
1510    ) -> Result<Vec<usize>> {
1511        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1512
1513        let per_layer_elems = {
1514            let input_layernorm = cfg.hidden_size;
1515            let post_attention_layernorm = cfg.hidden_size;
1516
1517            let size_in = cfg.hidden_size;
1518            let head_dim = cfg.head_dim();
1519            let op_size =
1520                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1521            let qkv_proj = size_in * op_size / weight_pack_factor;
1522            let o_proj =
1523                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1524
1525            let h_size = cfg.hidden_size;
1526            let i_size = cfg.intermediate_size;
1527            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1528            let down_proj = h_size * i_size / weight_pack_factor;
1529
1530            input_layernorm
1531                + post_attention_layernorm
1532                + qkv_proj
1533                + o_proj
1534                + gate_up_proj
1535                + down_proj
1536        };
1537        Ok(vec![
1538            per_layer_elems * dtype.size_in_bytes();
1539            cfg.num_hidden_layers
1540        ])
1541    }
1542
1543    fn num_layers(&self, config: &str) -> Result<usize> {
1544        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1545
1546        Ok(cfg.num_hidden_layers)
1547    }
1548
1549    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1550        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1551
1552        let cfg = ModelConfigMetadata {
1553            max_seq_len: cfg.max_position_embeddings,
1554            num_layers: cfg.num_hidden_layers,
1555            hidden_size: cfg.hidden_size,
1556            num_kv_heads: cfg.num_key_value_heads,
1557            num_attn_heads: cfg.num_attention_heads,
1558            sliding_window: cfg.sliding_window,
1559            k_head_dim: cfg.head_dim(),
1560            v_head_dim: cfg.head_dim(),
1561        };
1562
1563        Ok(Box::new(cfg))
1564    }
1565}
1566
1567// ======================== Qwen2 loader
1568
1569/// [`NormalLoader`] for a Qwen 2 model.
1570///
1571/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1572pub struct Qwen2Loader;
1573
1574impl NormalModelLoader for Qwen2Loader {
1575    fn load(
1576        &self,
1577        config: &str,
1578        vb: ShardedVarBuilder,
1579        normal_loading_metadata: NormalLoadingMetadata,
1580        attention_mechanism: AttentionImplementation,
1581    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1582        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1583
1584        Ok(Box::new(models::qwen2::Model::new(
1585            &cfg,
1586            vb,
1587            self.is_gptx(config)?,
1588            normal_loading_metadata,
1589            attention_mechanism,
1590        )?))
1591    }
1592    fn load_xlora(
1593        &self,
1594        _config: &str,
1595        _vb: ShardedVarBuilder,
1596        _lora_config: &[((String, String), LoraConfig)],
1597        _xlora_config: Option<XLoraConfig>,
1598        _xlora_ordering: Ordering,
1599        _normal_loading_metadata: NormalLoadingMetadata,
1600        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1601    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1602        todo!()
1603    }
1604    fn is_gptx(&self, _: &str) -> Result<bool> {
1605        Ok(true)
1606    }
1607    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1608        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1609
1610        Ok(Box::new(cfg))
1611    }
1612}
1613
1614impl IsqModelLoader for Qwen2Loader {
1615    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1616        Ok(vec![
1617            Regex::new(r"lm_head\.(weight|bias)$")?,
1618            // Attention
1619            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1620            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1621            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1622            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1623            // MLP
1624            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1625            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1626            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1627        ])
1628    }
1629    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1630        self.isq_layer_regexes(config)
1631    }
1632}
1633
1634impl DeviceMappedModelLoader for Qwen2Loader {
1635    fn mapped_max_act_size_elems(
1636        &self,
1637        config: &str,
1638        params: &AutoDeviceMapParams,
1639        prompt_chunksize: usize,
1640    ) -> Result<usize> {
1641        let AutoDeviceMapParams::Text {
1642            max_seq_len: _,
1643            max_batch_size,
1644        } = params
1645        else {
1646            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1647        };
1648
1649        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1650
1651        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1652    }
1653    fn non_mapped_max_act_size_elems(
1654        &self,
1655        _config: &str,
1656        _params: &AutoDeviceMapParams,
1657    ) -> Result<usize> {
1658        Ok(0)
1659    }
1660
1661    fn non_mapped_size_in_bytes(
1662        &self,
1663        config: &str,
1664        dtype: DType,
1665        weight_pack_factor: usize,
1666        _matformer_config: Option<&MatformerSliceConfig>,
1667    ) -> Result<usize> {
1668        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1669
1670        let elems = {
1671            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1672            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1673            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1674                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1675            } else {
1676                0
1677            };
1678            let norm = cfg.hidden_size;
1679            embed_tokens + lm_head + norm
1680        };
1681        Ok(elems * dtype.size_in_bytes())
1682    }
1683
1684    fn layer_sizes_in_bytes(
1685        &self,
1686        config: &str,
1687        dtype: DType,
1688        weight_pack_factor: usize,
1689        _matformer_config: Option<&MatformerSliceConfig>,
1690    ) -> Result<Vec<usize>> {
1691        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1692
1693        let per_layer_elems = {
1694            let input_layernorm = cfg.hidden_size;
1695            let post_attention_layernorm = cfg.hidden_size;
1696
1697            let size_in = cfg.hidden_size;
1698            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1699            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1700            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1701            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1702            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1703            let o_proj = size_q * size_in / weight_pack_factor;
1704
1705            let h_size = cfg.hidden_size;
1706            let i_size = cfg.intermediate_size;
1707            let gate_proj = h_size * i_size / weight_pack_factor;
1708            let up_proj = h_size * i_size / weight_pack_factor;
1709            let down_proj = i_size * h_size / weight_pack_factor;
1710
1711            input_layernorm
1712                + post_attention_layernorm
1713                + q_proj
1714                + k_proj
1715                + v_proj
1716                + o_proj
1717                + gate_proj
1718                + up_proj
1719                + down_proj
1720        };
1721        Ok(vec![
1722            per_layer_elems * dtype.size_in_bytes();
1723            cfg.num_hidden_layers
1724        ])
1725    }
1726
1727    fn num_layers(&self, config: &str) -> Result<usize> {
1728        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1729
1730        Ok(cfg.num_hidden_layers)
1731    }
1732
1733    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1734        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1735
1736        let cfg = ModelConfigMetadata {
1737            max_seq_len: cfg.max_position_embeddings,
1738            num_layers: cfg.num_hidden_layers,
1739            hidden_size: cfg.hidden_size,
1740            num_kv_heads: cfg.num_key_value_heads,
1741            num_attn_heads: cfg.num_attention_heads,
1742            sliding_window: cfg.sliding_window,
1743            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1744            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1745        };
1746
1747        Ok(Box::new(cfg))
1748    }
1749}
1750
1751// ======================== Gemma2 loader
1752
/// [`NormalLoader`] for a Gemma2 model.
///
/// A stateless unit type. Every trait method receives the model's JSON
/// configuration as a `&str` and deserializes it into
/// `crate::models::gemma2::Config` itself.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Gemma2Loader;
1757
1758impl NormalModelLoader for Gemma2Loader {
1759    fn load(
1760        &self,
1761        config: &str,
1762        vb: ShardedVarBuilder,
1763        normal_loading_metadata: NormalLoadingMetadata,
1764        attention_mechanism: AttentionImplementation,
1765    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1766        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1767
1768        Ok(Box::new(models::gemma2::Model::new(
1769            &cfg,
1770            vb,
1771            self.is_gptx(config)?,
1772            normal_loading_metadata,
1773            attention_mechanism,
1774        )?))
1775    }
1776    fn load_xlora(
1777        &self,
1778        config: &str,
1779        vb: ShardedVarBuilder,
1780        lora_config: &[((String, String), LoraConfig)],
1781        xlora_config: Option<XLoraConfig>,
1782        xlora_ordering: Ordering,
1783        normal_loading_metadata: NormalLoadingMetadata,
1784        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1785    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1786        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1787
1788        Ok(Box::new(xlora_models::XLoraGemma2::new(
1789            &cfg,
1790            vb,
1791            lora_config,
1792            xlora_config,
1793            xlora_ordering,
1794            self.is_gptx(config)?,
1795            normal_loading_metadata,
1796            preload_adapters,
1797        )?))
1798    }
1799    fn is_gptx(&self, _: &str) -> Result<bool> {
1800        Ok(true)
1801    }
1802    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1803        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1804
1805        Ok(Box::new(cfg))
1806    }
1807}
1808
1809impl IsqModelLoader for Gemma2Loader {
1810    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1811        Ok(vec![
1812            Regex::new(r"lm_head\.(weight|bias)$")?,
1813            // Attention
1814            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1815            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1816            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1817            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1818            // MLP
1819            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1820            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1821            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1822        ])
1823    }
1824    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1825        self.isq_layer_regexes(config)
1826    }
1827}
1828
1829impl DeviceMappedModelLoader for Gemma2Loader {
1830    fn mapped_max_act_size_elems(
1831        &self,
1832        config: &str,
1833        params: &AutoDeviceMapParams,
1834        prompt_chunksize: usize,
1835    ) -> Result<usize> {
1836        let AutoDeviceMapParams::Text {
1837            max_seq_len: _,
1838            max_batch_size,
1839        } = params
1840        else {
1841            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1842        };
1843
1844        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1845
1846        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1847    }
1848    fn non_mapped_max_act_size_elems(
1849        &self,
1850        _config: &str,
1851        _params: &AutoDeviceMapParams,
1852    ) -> Result<usize> {
1853        Ok(0)
1854    }
1855
1856    fn non_mapped_size_in_bytes(
1857        &self,
1858        config: &str,
1859        dtype: DType,
1860        weight_pack_factor: usize,
1861        _matformer_config: Option<&MatformerSliceConfig>,
1862    ) -> Result<usize> {
1863        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1864
1865        let elems = {
1866            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1867            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1868            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1869                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1870            } else {
1871                0
1872            };
1873            let norm = cfg.hidden_size;
1874            embed_tokens + lm_head + norm
1875        };
1876        Ok(elems * dtype.size_in_bytes())
1877    }
1878
1879    fn layer_sizes_in_bytes(
1880        &self,
1881        config: &str,
1882        dtype: DType,
1883        weight_pack_factor: usize,
1884        _matformer_config: Option<&MatformerSliceConfig>,
1885    ) -> Result<Vec<usize>> {
1886        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1887
1888        let per_layer_elems = {
1889            let input_layernorm = cfg.hidden_size;
1890            let post_attention_layernorm = cfg.hidden_size;
1891
1892            let size_in = cfg.hidden_size;
1893            let size_q = cfg.head_dim * cfg.num_attention_heads;
1894            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1895            let q_proj =
1896                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1897            let k_proj =
1898                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1899            let v_proj =
1900                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1901            let o_proj =
1902                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1903
1904            let h_size = cfg.hidden_size;
1905            let i_size = cfg.intermediate_size;
1906            let gate_proj = h_size * i_size / weight_pack_factor;
1907            let up_proj = h_size * i_size / weight_pack_factor;
1908            let down_proj = i_size * h_size / weight_pack_factor;
1909
1910            input_layernorm
1911                + post_attention_layernorm
1912                + q_proj
1913                + k_proj
1914                + v_proj
1915                + o_proj
1916                + gate_proj
1917                + up_proj
1918                + down_proj
1919        };
1920        Ok(vec![
1921            per_layer_elems * dtype.size_in_bytes();
1922            cfg.num_hidden_layers
1923        ])
1924    }
1925
1926    fn num_layers(&self, config: &str) -> Result<usize> {
1927        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1928
1929        Ok(cfg.num_hidden_layers)
1930    }
1931    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1932        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1933
1934        let cfg = ModelConfigMetadata {
1935            max_seq_len: cfg.max_position_embeddings,
1936            num_layers: cfg.num_hidden_layers,
1937            hidden_size: cfg.hidden_size,
1938            num_kv_heads: cfg.num_key_value_heads,
1939            num_attn_heads: cfg.num_attention_heads,
1940            sliding_window: None, // None to be more forgiving, some do not
1941            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1942            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1943        };
1944
1945        Ok(Box::new(cfg))
1946    }
1947}
1948
1949// ======================== Starcoder2 loader
1950
/// [`NormalLoader`] for a Starcoder2 model.
///
/// A stateless unit type. Every trait method receives the model's JSON
/// configuration as a `&str` and deserializes it into
/// `crate::models::starcoder2::Config` itself.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Starcoder2Loader;
1955
1956impl NormalModelLoader for Starcoder2Loader {
1957    fn load(
1958        &self,
1959        config: &str,
1960        vb: ShardedVarBuilder,
1961        normal_loading_metadata: NormalLoadingMetadata,
1962        attention_mechanism: AttentionImplementation,
1963    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1964        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1965
1966        Ok(Box::new(models::starcoder2::Model::new(
1967            &cfg,
1968            vb,
1969            self.is_gptx(config)?,
1970            normal_loading_metadata,
1971            attention_mechanism,
1972        )?))
1973    }
1974    fn load_xlora(
1975        &self,
1976        config: &str,
1977        vb: ShardedVarBuilder,
1978        lora_config: &[((String, String), LoraConfig)],
1979        xlora_config: Option<XLoraConfig>,
1980        xlora_ordering: Ordering,
1981        normal_loading_metadata: NormalLoadingMetadata,
1982        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1983    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1984        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1985
1986        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
1987            &cfg,
1988            vb,
1989            lora_config,
1990            xlora_config,
1991            xlora_ordering,
1992            self.is_gptx(config)?,
1993            normal_loading_metadata,
1994            preload_adapters,
1995        )?))
1996    }
1997    fn is_gptx(&self, _: &str) -> Result<bool> {
1998        Ok(true)
1999    }
2000    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2001        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2002
2003        Ok(Box::new(cfg))
2004    }
2005}
2006
2007impl IsqModelLoader for Starcoder2Loader {
2008    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2009        Ok(vec![
2010            Regex::new(r"lm_head\.(weight|bias)$")?,
2011            // Attention
2012            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2013            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2014            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2015            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2016            // MLP
2017            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2018            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2019        ])
2020    }
2021    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2022        self.isq_layer_regexes(config)
2023    }
2024}
2025
2026impl DeviceMappedModelLoader for Starcoder2Loader {
2027    fn mapped_max_act_size_elems(
2028        &self,
2029        config: &str,
2030        params: &AutoDeviceMapParams,
2031        prompt_chunksize: usize,
2032    ) -> Result<usize> {
2033        let AutoDeviceMapParams::Text {
2034            max_seq_len: _,
2035            max_batch_size,
2036        } = params
2037        else {
2038            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2039        };
2040
2041        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2042
2043        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2044    }
2045    fn non_mapped_max_act_size_elems(
2046        &self,
2047        _config: &str,
2048        _params: &AutoDeviceMapParams,
2049    ) -> Result<usize> {
2050        Ok(0)
2051    }
2052
2053    fn non_mapped_size_in_bytes(
2054        &self,
2055        config: &str,
2056        dtype: DType,
2057        weight_pack_factor: usize,
2058        _matformer_config: Option<&MatformerSliceConfig>,
2059    ) -> Result<usize> {
2060        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2061
2062        let elems = {
2063            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2064            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2065            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2066                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2067            } else {
2068                0
2069            };
2070            let norm = cfg.hidden_size + cfg.hidden_size;
2071            embed_tokens + lm_head + norm
2072        };
2073        Ok(elems * dtype.size_in_bytes())
2074    }
2075
2076    fn layer_sizes_in_bytes(
2077        &self,
2078        config: &str,
2079        dtype: DType,
2080        weight_pack_factor: usize,
2081        _matformer_config: Option<&MatformerSliceConfig>,
2082    ) -> Result<Vec<usize>> {
2083        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2084
2085        let per_layer_elems = {
2086            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2087            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2088
2089            let size_in = cfg.hidden_size;
2090            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2091            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2092            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2093            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2094            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2095            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2096
2097            let h_size = cfg.hidden_size;
2098            let i_size = cfg.intermediate_size;
2099            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2100            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2101
2102            input_layernorm
2103                + post_attention_layernorm
2104                + q_proj
2105                + k_proj
2106                + v_proj
2107                + o_proj
2108                + fc1
2109                + fc2
2110        };
2111        Ok(vec![
2112            per_layer_elems * dtype.size_in_bytes();
2113            cfg.num_hidden_layers
2114        ])
2115    }
2116
2117    fn num_layers(&self, config: &str) -> Result<usize> {
2118        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2119
2120        Ok(cfg.num_hidden_layers)
2121    }
2122
2123    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2124        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2125
2126        let cfg = ModelConfigMetadata {
2127            max_seq_len: cfg.max_position_embeddings,
2128            num_layers: cfg.num_hidden_layers,
2129            hidden_size: cfg.hidden_size,
2130            num_kv_heads: cfg.num_key_value_heads,
2131            num_attn_heads: cfg.num_attention_heads,
2132            sliding_window: cfg.sliding_window,
2133            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2134            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2135        };
2136
2137        Ok(Box::new(cfg))
2138    }
2139}
2140
// ======================== Phi 3.5 MoE loader
2142
/// [`NormalLoader`] for a Phi 3.5 MoE model.
///
/// A stateless unit type. The MoE experts are addressed under
/// `layers.N.block_sparse_moe.experts.M.{w1,w2,w3}`.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Phi3_5MoELoader;
2147
2148impl NormalModelLoader for Phi3_5MoELoader {
2149    fn load(
2150        &self,
2151        config: &str,
2152        vb: ShardedVarBuilder,
2153        normal_loading_metadata: NormalLoadingMetadata,
2154        attention_mechanism: AttentionImplementation,
2155    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2156        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2157
2158        Ok(Box::new(models::phi3_5_moe::Model::new(
2159            &cfg,
2160            vb,
2161            self.is_gptx(config)?,
2162            normal_loading_metadata,
2163            attention_mechanism,
2164        )?))
2165    }
2166    fn load_xlora(
2167        &self,
2168        config: &str,
2169        vb: ShardedVarBuilder,
2170        lora_config: &[((String, String), LoraConfig)],
2171        xlora_config: Option<XLoraConfig>,
2172        xlora_ordering: Ordering,
2173        normal_loading_metadata: NormalLoadingMetadata,
2174        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2175    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2176        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2177
2178        Ok(Box::new(xlora_models::XLoraPhi3::new(
2179            &cfg,
2180            vb,
2181            lora_config,
2182            xlora_config,
2183            xlora_ordering,
2184            self.is_gptx(config)?,
2185            normal_loading_metadata,
2186            preload_adapters,
2187        )?))
2188    }
2189    fn is_gptx(&self, _: &str) -> Result<bool> {
2190        Ok(true)
2191    }
2192    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2193        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2194
2195        Ok(Box::new(cfg))
2196    }
2197}
2198
2199impl IsqModelLoader for Phi3_5MoELoader {
2200    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2201        Ok(vec![
2202            Regex::new(r"lm_head\.(weight|bias)$")?,
2203            // Attention
2204            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2205            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2206            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2207            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2208            // MLP
2209            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2210            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2211            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2212        ])
2213    }
2214    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2215        self.isq_layer_regexes(config)
2216    }
2217
2218    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2219        Ok(vec![
2220            Regex::new(r"lm_head\.(weight|bias)$")?,
2221            // MLP
2222            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2223            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2224            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2225        ])
2226    }
2227    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2228        self.isq_layer_regexes_moqe(config)
2229    }
2230}
2231
2232impl DeviceMappedModelLoader for Phi3_5MoELoader {
2233    fn mapped_max_act_size_elems(
2234        &self,
2235        config: &str,
2236        params: &AutoDeviceMapParams,
2237        prompt_chunksize: usize,
2238    ) -> Result<usize> {
2239        let AutoDeviceMapParams::Text {
2240            max_seq_len: _,
2241            max_batch_size,
2242        } = params
2243        else {
2244            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2245        };
2246
2247        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2248
2249        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2250    }
2251    fn non_mapped_max_act_size_elems(
2252        &self,
2253        _config: &str,
2254        _params: &AutoDeviceMapParams,
2255    ) -> Result<usize> {
2256        Ok(0)
2257    }
2258
2259    fn non_mapped_size_in_bytes(
2260        &self,
2261        config: &str,
2262        dtype: DType,
2263        weight_pack_factor: usize,
2264        _matformer_config: Option<&MatformerSliceConfig>,
2265    ) -> Result<usize> {
2266        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2267
2268        let elems = {
2269            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2270            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2271            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2272                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2273            } else {
2274                0
2275            };
2276            let norm = cfg.hidden_size;
2277            embed_tokens + lm_head + norm
2278        };
2279        Ok(elems * dtype.size_in_bytes())
2280    }
2281
2282    fn layer_sizes_in_bytes(
2283        &self,
2284        config: &str,
2285        dtype: DType,
2286        weight_pack_factor: usize,
2287        _matformer_config: Option<&MatformerSliceConfig>,
2288    ) -> Result<Vec<usize>> {
2289        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2290
2291        let per_layer_elems = {
2292            let input_layernorm = cfg.hidden_size;
2293            let post_attention_layernorm = cfg.hidden_size;
2294
2295            let size_in = cfg.hidden_size;
2296            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2297            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2298            let q_proj =
2299                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2300            let k_proj =
2301                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2302            let v_proj =
2303                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2304            let o_proj =
2305                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2306
2307            let moe_block = {
2308                let gate = cfg.hidden_size * cfg.num_local_experts;
2309                // Assume quantizing weight pack factor
2310                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2311                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2312                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2313                gate + cfg.num_local_experts * w1
2314                    + cfg.num_local_experts * w2
2315                    + cfg.num_local_experts * w3
2316            };
2317
2318            input_layernorm
2319                + post_attention_layernorm
2320                + q_proj
2321                + k_proj
2322                + v_proj
2323                + o_proj
2324                + moe_block
2325        };
2326        Ok(vec![
2327            per_layer_elems * dtype.size_in_bytes();
2328            cfg.num_hidden_layers
2329        ])
2330    }
2331
2332    fn num_layers(&self, config: &str) -> Result<usize> {
2333        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2334
2335        Ok(cfg.num_hidden_layers)
2336    }
2337
2338    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2339        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2340
2341        let cfg = ModelConfigMetadata {
2342            max_seq_len: cfg.max_position_embeddings,
2343            num_layers: cfg.num_hidden_layers,
2344            hidden_size: cfg.hidden_size,
2345            num_kv_heads: cfg.num_key_value_heads,
2346            num_attn_heads: cfg.num_attention_heads,
2347            sliding_window: cfg.sliding_window,
2348            k_head_dim: cfg.head_dim(),
2349            v_head_dim: cfg.head_dim(),
2350        };
2351
2352        Ok(Box::new(cfg))
2353    }
2354}
2355
// ======================== DeepSeekV2 loader

/// [`NormalLoader`] for a DeepSeekV2 model.
///
/// A stateless unit type. Every trait method deserializes the JSON config into
/// `crate::models::deepseek2::DeepSeekV2Config`. Adapter (LoRA/X-LoRA) loading is not
/// implemented for this architecture.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct DeepSeekV2Loader;
2360
2361impl NormalModelLoader for DeepSeekV2Loader {
2362    fn load(
2363        &self,
2364        config: &str,
2365        vb: ShardedVarBuilder,
2366        normal_loading_metadata: NormalLoadingMetadata,
2367        attention_mechanism: AttentionImplementation,
2368    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2369        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2370
2371        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2372            &cfg,
2373            vb,
2374            self.is_gptx(config)?,
2375            normal_loading_metadata,
2376            attention_mechanism,
2377        )?))
2378    }
2379    fn load_xlora(
2380        &self,
2381        _config: &str,
2382        _vb: ShardedVarBuilder,
2383        _lora_config: &[((String, String), LoraConfig)],
2384        _xlora_config: Option<XLoraConfig>,
2385        _xlora_ordering: Ordering,
2386        _normal_loading_metadata: NormalLoadingMetadata,
2387        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2388    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2389        todo!()
2390    }
2391    fn is_gptx(&self, _: &str) -> Result<bool> {
2392        Ok(true)
2393    }
2394    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2395        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2396        Ok(Box::new(cfg))
2397    }
2398}
2399
2400impl IsqModelLoader for DeepSeekV2Loader {
2401    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2402        let mut data = vec![
2403            Regex::new(r"lm_head\.(weight|bias)$")?,
2404            // Attention
2405            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2406            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2407            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2408        ];
2409        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2410        if cfg.q_lora_rank.is_some() {
2411            data.extend(vec![
2412                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2413                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2414            ]);
2415        } else {
2416            data.push(Regex::new(
2417                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2418            )?);
2419        }
2420        for layer_idx in 0..cfg.num_hidden_layers {
2421            if cfg.n_routed_experts.is_some()
2422                && layer_idx >= cfg.first_k_dense_replace
2423                && layer_idx % cfg.moe_layer_freq == 0
2424            {
2425                for i in 0..cfg.n_routed_experts.unwrap() {
2426                    data.extend(vec![
2427                        Regex::new(&format!(
2428                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2429                        ))?,
2430                        Regex::new(&format!(
2431                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2432                        ))?,
2433                        Regex::new(&format!(
2434                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2435                        ))?,
2436                    ]);
2437                }
2438                if cfg.n_shared_experts.is_some() {
2439                    data.extend(vec![
2440                        Regex::new(&format!(
2441                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2442                        ))?,
2443                        Regex::new(&format!(
2444                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2445                        ))?,
2446                        Regex::new(&format!(
2447                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2448                        ))?,
2449                    ]);
2450                }
2451            } else {
2452                data.extend(vec![
2453                    Regex::new(&format!(
2454                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2455                    ))?,
2456                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2457                    Regex::new(&format!(
2458                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2459                    ))?,
2460                ]);
2461            };
2462        }
2463        Ok(data)
2464    }
2465    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2466        self.isq_layer_regexes(config)
2467    }
2468
2469    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2470        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2471        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2472        for layer_idx in 0..cfg.num_hidden_layers {
2473            if cfg.n_routed_experts.is_some()
2474                && layer_idx >= cfg.first_k_dense_replace
2475                && layer_idx % cfg.moe_layer_freq == 0
2476            {
2477                for i in 0..cfg.n_routed_experts.unwrap() {
2478                    data.extend(vec![
2479                        Regex::new(&format!(
2480                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2481                        ))?,
2482                        Regex::new(&format!(
2483                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2484                        ))?,
2485                        Regex::new(&format!(
2486                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2487                        ))?,
2488                    ]);
2489                }
2490                if cfg.n_shared_experts.is_some() {
2491                    data.extend(vec![
2492                        Regex::new(&format!(
2493                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2494                        ))?,
2495                        Regex::new(&format!(
2496                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2497                        ))?,
2498                        Regex::new(&format!(
2499                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2500                        ))?,
2501                    ]);
2502                }
2503            } else {
2504                data.extend(vec![
2505                    Regex::new(&format!(
2506                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2507                    ))?,
2508                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2509                    Regex::new(&format!(
2510                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2511                    ))?,
2512                ]);
2513            };
2514        }
2515        Ok(data)
2516    }
2517    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2518        self.isq_layer_regexes_moqe(config)
2519    }
2520}
2521
2522impl DeviceMappedModelLoader for DeepSeekV2Loader {
2523    fn mapped_max_act_size_elems(
2524        &self,
2525        config: &str,
2526        params: &AutoDeviceMapParams,
2527        prompt_chunksize: usize,
2528    ) -> Result<usize> {
2529        let AutoDeviceMapParams::Text {
2530            max_seq_len: _,
2531            max_batch_size,
2532        } = params
2533        else {
2534            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2535        };
2536
2537        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2538
2539        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2540    }
2541    fn non_mapped_max_act_size_elems(
2542        &self,
2543        _config: &str,
2544        _params: &AutoDeviceMapParams,
2545    ) -> Result<usize> {
2546        Ok(0)
2547    }
2548
2549    fn non_mapped_size_in_bytes(
2550        &self,
2551        config: &str,
2552        dtype: DType,
2553        weight_pack_factor: usize,
2554        _matformer_config: Option<&MatformerSliceConfig>,
2555    ) -> Result<usize> {
2556        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2557        let elems = {
2558            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2559            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2560            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2561                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2562            } else {
2563                0
2564            };
2565            let norm = cfg.hidden_size;
2566            embed_tokens + lm_head + norm
2567        };
2568        Ok(elems * dtype.size_in_bytes())
2569    }
2570
2571    fn layer_sizes_in_bytes(
2572        &self,
2573        config: &str,
2574        dtype: DType,
2575        weight_pack_factor: usize,
2576        _matformer_config: Option<&MatformerSliceConfig>,
2577    ) -> Result<Vec<usize>> {
2578        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2579        let mut per_layer_elems = Vec::new();
2580
2581        for layer_idx in 0..cfg.num_hidden_layers {
2582            let input_layernorm = cfg.hidden_size;
2583            let post_attention_layernorm = cfg.hidden_size;
2584
2585            let q_proj = match cfg.q_lora_rank {
2586                Some(lora_rank) => {
2587                    let a = cfg.hidden_size * lora_rank;
2588                    let norm = lora_rank;
2589                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2590                    a + norm + b
2591                }
2592                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2593            };
2594            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2595                / weight_pack_factor
2596                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2597            let kv_a_layernorm = cfg.kv_lora_rank;
2598            let kv_b_proj = cfg.kv_lora_rank
2599                * cfg.num_attention_heads
2600                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2601                / weight_pack_factor;
2602            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2603                / weight_pack_factor
2604                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2605
2606            let moe_block = {
2607                let mut sum = 0;
2608                if cfg.n_routed_experts.is_some()
2609                    && layer_idx >= cfg.first_k_dense_replace
2610                    && layer_idx % cfg.moe_layer_freq == 0
2611                {
2612                    let h_size = cfg.hidden_size;
2613                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2614                        * cfg.n_routed_experts.unwrap();
2615                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2616                        * cfg.n_routed_experts.unwrap();
2617                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2618                        * cfg.n_routed_experts.unwrap();
2619                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2620                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2621                            / weight_pack_factor;
2622                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2623                            / weight_pack_factor;
2624                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2625                            / weight_pack_factor;
2626                        gate_proj + up_proj + down_proj
2627                    } else {
2628                        0
2629                    };
2630                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2631                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2632                } else {
2633                    let h_size = cfg.hidden_size;
2634                    let i_size = cfg.intermediate_size;
2635                    let gate_proj = h_size * i_size / weight_pack_factor;
2636                    let up_proj = h_size * i_size / weight_pack_factor;
2637                    let down_proj = i_size * h_size / weight_pack_factor;
2638                    sum += gate_proj + up_proj + down_proj;
2639                }
2640                sum
2641            };
2642
2643            per_layer_elems.push(
2644                input_layernorm
2645                    + post_attention_layernorm
2646                    + q_proj
2647                    + kv_a_layernorm
2648                    + kv_a_proj_with_mqa
2649                    + kv_b_proj
2650                    + o_proj
2651                    + moe_block,
2652            );
2653        }
2654
2655        Ok(per_layer_elems
2656            .into_iter()
2657            .map(|x| x * dtype.size_in_bytes())
2658            .collect())
2659    }
2660
2661    fn num_layers(&self, config: &str) -> Result<usize> {
2662        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2663        Ok(cfg.num_hidden_layers)
2664    }
2665
2666    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2667        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2668
2669        let cfg = ModelConfigMetadata {
2670            max_seq_len: cfg.max_position_embeddings,
2671            num_layers: cfg.num_hidden_layers,
2672            hidden_size: cfg.hidden_size,
2673            num_kv_heads: cfg.num_attention_heads,
2674            num_attn_heads: cfg.num_attention_heads,
2675            sliding_window: None,
2676            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2677            v_head_dim: cfg.v_head_dim,
2678        };
2679
2680        Ok(Box::new(cfg))
2681    }
2682}
2683
/// Architecture loader used by [`NormalLoader`] to build DeepSeekV3 models.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct DeepSeekV3Loader;
2688
2689impl NormalModelLoader for DeepSeekV3Loader {
2690    fn load(
2691        &self,
2692        config: &str,
2693        vb: ShardedVarBuilder,
2694        normal_loading_metadata: NormalLoadingMetadata,
2695        attention_mechanism: AttentionImplementation,
2696    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2697        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2698        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2699            &cfg,
2700            vb,
2701            self.is_gptx(config)?,
2702            normal_loading_metadata,
2703            attention_mechanism,
2704        )?))
2705    }
2706    fn load_xlora(
2707        &self,
2708        _config: &str,
2709        _vb: ShardedVarBuilder,
2710        _lora_config: &[((String, String), LoraConfig)],
2711        _xlora_config: Option<XLoraConfig>,
2712        _xlora_ordering: Ordering,
2713        _normal_loading_metadata: NormalLoadingMetadata,
2714        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2715    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2716        todo!()
2717    }
2718    fn is_gptx(&self, _: &str) -> Result<bool> {
2719        Ok(true)
2720    }
2721    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2722        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2723        Ok(Box::new(cfg))
2724    }
2725}
2726
2727impl IsqModelLoader for DeepSeekV3Loader {
2728    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2729        let mut data = vec![
2730            Regex::new(r"lm_head\.(weight|bias)$")?,
2731            // Attention
2732            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2733            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2734            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2735        ];
2736        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2737        if cfg.q_lora_rank.is_some() {
2738            data.extend(vec![
2739                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2740                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2741            ]);
2742        } else {
2743            data.push(Regex::new(
2744                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2745            )?);
2746        }
2747        for layer_idx in 0..cfg.num_hidden_layers {
2748            if cfg.n_routed_experts.is_some()
2749                && layer_idx >= cfg.first_k_dense_replace
2750                && layer_idx % cfg.moe_layer_freq == 0
2751            {
2752                for i in 0..cfg.n_routed_experts.unwrap() {
2753                    data.extend(vec![
2754                        Regex::new(&format!(
2755                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2756                        ))?,
2757                        Regex::new(&format!(
2758                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2759                        ))?,
2760                        Regex::new(&format!(
2761                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2762                        ))?,
2763                    ]);
2764                }
2765                if cfg.n_shared_experts.is_some() {
2766                    data.extend(vec![
2767                        Regex::new(&format!(
2768                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2769                        ))?,
2770                        Regex::new(&format!(
2771                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2772                        ))?,
2773                        Regex::new(&format!(
2774                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2775                        ))?,
2776                    ]);
2777                }
2778            } else {
2779                data.extend(vec![
2780                    Regex::new(&format!(
2781                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2782                    ))?,
2783                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2784                    Regex::new(&format!(
2785                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2786                    ))?,
2787                ]);
2788            };
2789        }
2790        Ok(data)
2791    }
2792    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2793        self.isq_layer_regexes(config)
2794    }
2795
2796    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2797        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2798        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2799        for layer_idx in 0..cfg.num_hidden_layers {
2800            if cfg.n_routed_experts.is_some()
2801                && layer_idx >= cfg.first_k_dense_replace
2802                && layer_idx % cfg.moe_layer_freq == 0
2803            {
2804                for i in 0..cfg.n_routed_experts.unwrap() {
2805                    data.extend(vec![
2806                        Regex::new(&format!(
2807                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2808                        ))?,
2809                        Regex::new(&format!(
2810                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2811                        ))?,
2812                        Regex::new(&format!(
2813                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2814                        ))?,
2815                    ]);
2816                }
2817                if cfg.n_shared_experts.is_some() {
2818                    data.extend(vec![
2819                        Regex::new(&format!(
2820                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2821                        ))?,
2822                        Regex::new(&format!(
2823                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2824                        ))?,
2825                        Regex::new(&format!(
2826                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2827                        ))?,
2828                    ]);
2829                }
2830            } else {
2831                data.extend(vec![
2832                    Regex::new(&format!(
2833                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2834                    ))?,
2835                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2836                    Regex::new(&format!(
2837                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2838                    ))?,
2839                ]);
2840            };
2841        }
2842        Ok(data)
2843    }
2844    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2845        self.isq_layer_regexes_moqe(config)
2846    }
2847}
2848
2849impl DeviceMappedModelLoader for DeepSeekV3Loader {
2850    fn mapped_max_act_size_elems(
2851        &self,
2852        config: &str,
2853        params: &AutoDeviceMapParams,
2854        prompt_chunksize: usize,
2855    ) -> Result<usize> {
2856        let AutoDeviceMapParams::Text {
2857            max_seq_len: _,
2858            max_batch_size,
2859        } = params
2860        else {
2861            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2862        };
2863
2864        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2865
2866        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2867    }
2868    fn non_mapped_max_act_size_elems(
2869        &self,
2870        _config: &str,
2871        _params: &AutoDeviceMapParams,
2872    ) -> Result<usize> {
2873        Ok(0)
2874    }
2875
2876    fn non_mapped_size_in_bytes(
2877        &self,
2878        config: &str,
2879        dtype: DType,
2880        weight_pack_factor: usize,
2881        _matformer_config: Option<&MatformerSliceConfig>,
2882    ) -> Result<usize> {
2883        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2884        let elems = {
2885            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2886            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2887            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2888                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2889            } else {
2890                0
2891            };
2892            let norm = cfg.hidden_size;
2893            embed_tokens + lm_head + norm
2894        };
2895        Ok(elems * dtype.size_in_bytes())
2896    }
2897
2898    fn layer_sizes_in_bytes(
2899        &self,
2900        config: &str,
2901        dtype: DType,
2902        weight_pack_factor: usize,
2903        _matformer_config: Option<&MatformerSliceConfig>,
2904    ) -> Result<Vec<usize>> {
2905        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2906        let mut per_layer_elems = Vec::new();
2907
2908        for layer_idx in 0..cfg.num_hidden_layers {
2909            let input_layernorm = cfg.hidden_size;
2910            let post_attention_layernorm = cfg.hidden_size;
2911
2912            let q_proj = match cfg.q_lora_rank {
2913                Some(lora_rank) => {
2914                    let a = cfg.hidden_size * lora_rank;
2915                    let norm = lora_rank;
2916                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2917                    a + norm + b
2918                }
2919                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2920            };
2921            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2922                / weight_pack_factor
2923                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2924            let kv_a_layernorm = cfg.kv_lora_rank;
2925            let kv_b_proj = cfg.kv_lora_rank
2926                * cfg.num_attention_heads
2927                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2928                / weight_pack_factor;
2929            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2930                / weight_pack_factor
2931                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2932
2933            let moe_block = {
2934                let mut sum = 0;
2935                if cfg.n_routed_experts.is_some()
2936                    && layer_idx >= cfg.first_k_dense_replace
2937                    && layer_idx % cfg.moe_layer_freq == 0
2938                {
2939                    let h_size = cfg.hidden_size;
2940                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2941                        * cfg.n_routed_experts.unwrap();
2942                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2943                        * cfg.n_routed_experts.unwrap();
2944                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2945                        * cfg.n_routed_experts.unwrap();
2946                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2947                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2948                            / weight_pack_factor;
2949                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2950                            / weight_pack_factor;
2951                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2952                            / weight_pack_factor;
2953                        gate_proj + up_proj + down_proj
2954                    } else {
2955                        0
2956                    };
2957                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2958                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2959                } else {
2960                    let h_size = cfg.hidden_size;
2961                    let i_size = cfg.intermediate_size;
2962                    let gate_proj = h_size * i_size / weight_pack_factor;
2963                    let up_proj = h_size * i_size / weight_pack_factor;
2964                    let down_proj = i_size * h_size / weight_pack_factor;
2965                    sum += gate_proj + up_proj + down_proj;
2966                }
2967                sum
2968            };
2969
2970            per_layer_elems.push(
2971                input_layernorm
2972                    + post_attention_layernorm
2973                    + q_proj
2974                    + kv_a_layernorm
2975                    + kv_a_proj_with_mqa
2976                    + kv_b_proj
2977                    + o_proj
2978                    + moe_block,
2979            );
2980        }
2981
2982        Ok(per_layer_elems
2983            .into_iter()
2984            .map(|x| x * dtype.size_in_bytes())
2985            .collect())
2986    }
2987
2988    fn num_layers(&self, config: &str) -> Result<usize> {
2989        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2990        Ok(cfg.num_hidden_layers)
2991    }
2992
2993    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2994        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2995
2996        let cfg = ModelConfigMetadata {
2997            max_seq_len: cfg.max_position_embeddings,
2998            num_layers: cfg.num_hidden_layers,
2999            hidden_size: cfg.hidden_size,
3000            num_kv_heads: cfg.num_attention_heads,
3001            num_attn_heads: cfg.num_attention_heads,
3002            sliding_window: None,
3003            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3004            v_head_dim: cfg.v_head_dim,
3005        };
3006
3007        Ok(Box::new(cfg))
3008    }
3009}
3010
/// Architecture loader used by [`NormalLoader`] to build Qwen 3 models.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Qwen3Loader;
3015
3016impl NormalModelLoader for Qwen3Loader {
3017    fn load(
3018        &self,
3019        config: &str,
3020        vb: ShardedVarBuilder,
3021        normal_loading_metadata: NormalLoadingMetadata,
3022        attention_mechanism: AttentionImplementation,
3023    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3024        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3025
3026        Ok(Box::new(models::qwen3::Model::new(
3027            &cfg,
3028            vb,
3029            self.is_gptx(config)?,
3030            normal_loading_metadata,
3031            attention_mechanism,
3032        )?))
3033    }
3034    fn load_xlora(
3035        &self,
3036        _config: &str,
3037        _vb: ShardedVarBuilder,
3038        _lora_config: &[((String, String), LoraConfig)],
3039        _xlora_config: Option<XLoraConfig>,
3040        _xlora_ordering: Ordering,
3041        _normal_loading_metadata: NormalLoadingMetadata,
3042        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3043    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3044        todo!()
3045    }
3046    fn is_gptx(&self, _: &str) -> Result<bool> {
3047        Ok(true)
3048    }
3049    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3050        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3051
3052        Ok(Box::new(cfg))
3053    }
3054}
3055
3056impl IsqModelLoader for Qwen3Loader {
3057    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3058        Ok(vec![
3059            Regex::new(r"lm_head\.(weight|bias)$")?,
3060            // Attention
3061            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3062            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3063            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3064            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3065            // MLP
3066            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3067            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3068            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3069        ])
3070    }
3071    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3072        self.isq_layer_regexes(config)
3073    }
3074}
3075
3076impl DeviceMappedModelLoader for Qwen3Loader {
3077    fn mapped_max_act_size_elems(
3078        &self,
3079        config: &str,
3080        params: &AutoDeviceMapParams,
3081        prompt_chunksize: usize,
3082    ) -> Result<usize> {
3083        let AutoDeviceMapParams::Text {
3084            max_seq_len: _,
3085            max_batch_size,
3086        } = params
3087        else {
3088            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3089        };
3090
3091        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3092
3093        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3094    }
3095    fn non_mapped_max_act_size_elems(
3096        &self,
3097        _config: &str,
3098        _params: &AutoDeviceMapParams,
3099    ) -> Result<usize> {
3100        Ok(0)
3101    }
3102
3103    fn non_mapped_size_in_bytes(
3104        &self,
3105        config: &str,
3106        dtype: DType,
3107        weight_pack_factor: usize,
3108        _matformer_config: Option<&MatformerSliceConfig>,
3109    ) -> Result<usize> {
3110        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3111        let elems = {
3112            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3113            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3114            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3115                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3116            } else {
3117                0
3118            };
3119            let norm = cfg.hidden_size;
3120            embed_tokens + lm_head + norm
3121        };
3122        Ok(elems * dtype.size_in_bytes())
3123    }
3124
3125    fn layer_sizes_in_bytes(
3126        &self,
3127        config: &str,
3128        dtype: DType,
3129        weight_pack_factor: usize,
3130        _matformer_config: Option<&MatformerSliceConfig>,
3131    ) -> Result<Vec<usize>> {
3132        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3133        let per_layer_elems = {
3134            let input_layernorm = cfg.hidden_size;
3135            let post_attention_layernorm = cfg.hidden_size;
3136
3137            let size_in = cfg.hidden_size;
3138            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3139            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3140            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3141            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3142            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3143            let o_proj = size_q * size_in / weight_pack_factor;
3144
3145            let h_size = cfg.hidden_size;
3146            let i_size = cfg.intermediate_size;
3147            let gate_proj = h_size * i_size / weight_pack_factor;
3148            let up_proj = h_size * i_size / weight_pack_factor;
3149            let down_proj = i_size * h_size / weight_pack_factor;
3150
3151            let q_norm = cfg.head_dim();
3152            let k_norm = cfg.head_dim();
3153
3154            input_layernorm
3155                + post_attention_layernorm
3156                + q_proj
3157                + k_proj
3158                + v_proj
3159                + o_proj
3160                + gate_proj
3161                + up_proj
3162                + down_proj
3163                + q_norm
3164                + k_norm
3165        };
3166        Ok(vec![
3167            per_layer_elems * dtype.size_in_bytes();
3168            cfg.num_hidden_layers
3169        ])
3170    }
3171
3172    fn num_layers(&self, config: &str) -> Result<usize> {
3173        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3174        Ok(cfg.num_hidden_layers)
3175    }
3176
3177    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3178        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3179
3180        let cfg = ModelConfigMetadata {
3181            max_seq_len: cfg.max_position_embeddings,
3182            num_layers: cfg.num_hidden_layers,
3183            hidden_size: cfg.hidden_size,
3184            num_kv_heads: cfg.num_key_value_heads,
3185            num_attn_heads: cfg.num_attention_heads,
3186            sliding_window: cfg.sliding_window,
3187            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3188            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3189        };
3190
3191        Ok(Box::new(cfg))
3192    }
3193}
3194
/// Loader that builds GLM 4 text models; plug it into a [`NormalLoader`].
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct GLM4Loader;
3199
3200impl NormalModelLoader for GLM4Loader {
3201    fn load(
3202        &self,
3203        config: &str,
3204        vb: ShardedVarBuilder,
3205        normal_loading_metadata: NormalLoadingMetadata,
3206        attention_mechanism: AttentionImplementation,
3207    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3208        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3209
3210        Ok(Box::new(models::glm4::Model::new(
3211            &cfg,
3212            vb,
3213            self.is_gptx(config)?,
3214            normal_loading_metadata,
3215            attention_mechanism,
3216        )?))
3217    }
3218    fn load_xlora(
3219        &self,
3220        _config: &str,
3221        _vb: ShardedVarBuilder,
3222        _lora_config: &[((String, String), LoraConfig)],
3223        _xlora_config: Option<XLoraConfig>,
3224        _xlora_ordering: Ordering,
3225        _normal_loading_metadata: NormalLoadingMetadata,
3226        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3227    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3228        todo!()
3229    }
3230    fn is_gptx(&self, _: &str) -> Result<bool> {
3231        Ok(true)
3232    }
3233    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3234        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3235
3236        Ok(Box::new(cfg))
3237    }
3238}
3239
3240impl IsqModelLoader for GLM4Loader {
3241    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3242        Ok(vec![
3243            Regex::new(r"lm_head\.(weight|bias)$")?,
3244            // Attention
3245            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3246            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3247            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3248            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3249            // MLP
3250            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3251            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3252            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3253        ])
3254    }
3255    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3256        self.isq_layer_regexes(config)
3257    }
3258}
3259
3260impl DeviceMappedModelLoader for GLM4Loader {
3261    fn mapped_max_act_size_elems(
3262        &self,
3263        config: &str,
3264        params: &AutoDeviceMapParams,
3265        prompt_chunksize: usize,
3266    ) -> Result<usize> {
3267        let AutoDeviceMapParams::Text {
3268            max_seq_len: _,
3269            max_batch_size,
3270        } = params
3271        else {
3272            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3273        };
3274
3275        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3276
3277        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3278    }
3279    fn non_mapped_max_act_size_elems(
3280        &self,
3281        _config: &str,
3282        _params: &AutoDeviceMapParams,
3283    ) -> Result<usize> {
3284        Ok(0)
3285    }
3286
3287    fn non_mapped_size_in_bytes(
3288        &self,
3289        config: &str,
3290        dtype: DType,
3291        weight_pack_factor: usize,
3292        _matformer_config: Option<&MatformerSliceConfig>,
3293    ) -> Result<usize> {
3294        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3295        let elems = {
3296            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3297            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3298            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3299                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3300            } else {
3301                0
3302            };
3303            let norm = cfg.hidden_size;
3304            embed_tokens + lm_head + norm
3305        };
3306        Ok(elems * dtype.size_in_bytes())
3307    }
3308
3309    fn layer_sizes_in_bytes(
3310        &self,
3311        config: &str,
3312        dtype: DType,
3313        weight_pack_factor: usize,
3314        _matformer_config: Option<&MatformerSliceConfig>,
3315    ) -> Result<Vec<usize>> {
3316        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3317        let per_layer_elems = {
3318            let input_layernorm = cfg.hidden_size;
3319            let post_attention_layernorm = cfg.hidden_size * 3; //+post_self_attn_layernorm and post_mlp_layernorm
3320
3321            let size_in = cfg.hidden_size;
3322            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3323            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3324            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3325            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3326            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3327            let o_proj = size_q * size_in / weight_pack_factor;
3328
3329            let h_size = cfg.hidden_size;
3330            let i_size = cfg.intermediate_size;
3331            let gate_proj = h_size * i_size / weight_pack_factor;
3332            let up_proj = h_size * i_size / weight_pack_factor;
3333            let down_proj = i_size * h_size / weight_pack_factor;
3334
3335            input_layernorm
3336                + post_attention_layernorm
3337                + q_proj
3338                + k_proj
3339                + v_proj
3340                + o_proj
3341                + gate_proj
3342                + up_proj
3343                + down_proj
3344        };
3345        Ok(vec![
3346            per_layer_elems * dtype.size_in_bytes();
3347            cfg.num_hidden_layers
3348        ])
3349    }
3350
3351    fn num_layers(&self, config: &str) -> Result<usize> {
3352        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3353        Ok(cfg.num_hidden_layers)
3354    }
3355
3356    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3357        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3358
3359        let cfg = ModelConfigMetadata {
3360            max_seq_len: cfg.max_position_embeddings,
3361            num_layers: cfg.num_hidden_layers,
3362            hidden_size: cfg.hidden_size,
3363            num_kv_heads: cfg.num_key_value_heads,
3364            num_attn_heads: cfg.num_attention_heads,
3365            sliding_window: cfg.sliding_window,
3366            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3367            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3368        };
3369
3370        Ok(Box::new(cfg))
3371    }
3372}
3373
/// Loader that builds Qwen 3 mixture-of-experts text models; plug it into a [`NormalLoader`].
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Qwen3MoELoader;
3378
3379impl NormalModelLoader for Qwen3MoELoader {
3380    fn load(
3381        &self,
3382        config: &str,
3383        vb: ShardedVarBuilder,
3384        normal_loading_metadata: NormalLoadingMetadata,
3385        attention_mechanism: AttentionImplementation,
3386    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3387        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3388
3389        Ok(Box::new(models::qwen3_moe::Model::new(
3390            &cfg,
3391            vb,
3392            self.is_gptx(config)?,
3393            normal_loading_metadata,
3394            attention_mechanism,
3395        )?))
3396    }
3397    fn load_xlora(
3398        &self,
3399        _config: &str,
3400        _vb: ShardedVarBuilder,
3401        _lora_config: &[((String, String), LoraConfig)],
3402        _xlora_config: Option<XLoraConfig>,
3403        _xlora_ordering: Ordering,
3404        _normal_loading_metadata: NormalLoadingMetadata,
3405        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3406    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3407        todo!()
3408    }
3409    fn is_gptx(&self, _: &str) -> Result<bool> {
3410        Ok(true)
3411    }
3412    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3413        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3414
3415        Ok(Box::new(cfg))
3416    }
3417}
3418
3419impl IsqModelLoader for Qwen3MoELoader {
3420    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3421        Ok(vec![
3422            Regex::new(r"lm_head\.(weight|bias)$")?,
3423            // Attention
3424            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3425            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3426            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3427            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3428            // MLP
3429            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3430            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3431            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3432            // MLP MoE
3433            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3434            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3435            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3436        ])
3437    }
3438    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3439        self.isq_layer_regexes(config)
3440    }
3441    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3442        self.isq_layer_regexes_moqe(config)
3443    }
3444}
3445
3446impl DeviceMappedModelLoader for Qwen3MoELoader {
3447    fn mapped_max_act_size_elems(
3448        &self,
3449        config: &str,
3450        params: &AutoDeviceMapParams,
3451        prompt_chunksize: usize,
3452    ) -> Result<usize> {
3453        let AutoDeviceMapParams::Text {
3454            max_seq_len: _,
3455            max_batch_size,
3456        } = params
3457        else {
3458            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3459        };
3460
3461        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3462
3463        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3464    }
3465    fn non_mapped_max_act_size_elems(
3466        &self,
3467        _config: &str,
3468        _params: &AutoDeviceMapParams,
3469    ) -> Result<usize> {
3470        Ok(0)
3471    }
3472
3473    fn non_mapped_size_in_bytes(
3474        &self,
3475        config: &str,
3476        dtype: DType,
3477        weight_pack_factor: usize,
3478        _matformer_config: Option<&MatformerSliceConfig>,
3479    ) -> Result<usize> {
3480        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3481        let elems = {
3482            let embed_tokens = cfg.hidden_size * cfg.vocab_size;
3483            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3484            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3485                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3486            } else {
3487                0
3488            };
3489            let norm = cfg.hidden_size;
3490            embed_tokens + lm_head + norm
3491        };
3492        Ok(elems * dtype.size_in_bytes())
3493    }
3494
3495    fn layer_sizes_in_bytes(
3496        &self,
3497        config: &str,
3498        dtype: DType,
3499        weight_pack_factor: usize,
3500        _matformer_config: Option<&MatformerSliceConfig>,
3501    ) -> Result<Vec<usize>> {
3502        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3503
3504        let mut layer_sizes_in_bytes = Vec::new();
3505        for layer_idx in 0..cfg.num_hidden_layers {
3506            let input_layernorm = cfg.hidden_size;
3507            let post_attention_layernorm = cfg.hidden_size;
3508
3509            let size_in = cfg.hidden_size;
3510            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3511            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3512            let q_proj = size_in * size_q / weight_pack_factor;
3513            let k_proj = size_in * size_kv / weight_pack_factor;
3514            let v_proj = size_in * size_kv / weight_pack_factor;
3515            let o_proj = size_q * size_in / weight_pack_factor;
3516
3517            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3518                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3519            {
3520                let gate_size = cfg.hidden_size * cfg.num_experts;
3521                let expert_size = {
3522                    let h_size = cfg.hidden_size;
3523                    let i_size = cfg.moe_intermediate_size;
3524                    let gate_proj = h_size * i_size / weight_pack_factor;
3525                    let up_proj = h_size * i_size / weight_pack_factor;
3526                    let down_proj = i_size * h_size / weight_pack_factor;
3527                    gate_proj + up_proj + down_proj
3528                };
3529                expert_size * cfg.num_experts + gate_size
3530            } else {
3531                let h_size = cfg.hidden_size;
3532                let i_size = cfg.intermediate_size;
3533                let gate_proj = h_size * i_size / weight_pack_factor;
3534                let up_proj = h_size * i_size / weight_pack_factor;
3535                let down_proj = i_size * h_size / weight_pack_factor;
3536                gate_proj + up_proj + down_proj
3537            };
3538
3539            let q_norm = cfg.head_dim();
3540            let k_norm = cfg.head_dim();
3541
3542            let size_elems = input_layernorm
3543                + post_attention_layernorm
3544                + q_proj
3545                + k_proj
3546                + v_proj
3547                + o_proj
3548                + mlp_size
3549                + q_norm
3550                + k_norm;
3551
3552            let size_in_bytes = size_elems * dtype.size_in_bytes();
3553            layer_sizes_in_bytes.push(size_in_bytes);
3554        }
3555
3556        Ok(layer_sizes_in_bytes)
3557    }
3558
3559    fn num_layers(&self, config: &str) -> Result<usize> {
3560        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3561        Ok(cfg.num_hidden_layers)
3562    }
3563
3564    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3565        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3566
3567        let cfg = ModelConfigMetadata {
3568            max_seq_len: cfg.max_position_embeddings,
3569            num_layers: cfg.num_hidden_layers,
3570            hidden_size: cfg.hidden_size,
3571            num_kv_heads: cfg.num_key_value_heads,
3572            num_attn_heads: cfg.num_attention_heads,
3573            sliding_window: cfg.sliding_window,
3574            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3575            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3576        };
3577
3578        Ok(Box::new(cfg))
3579    }
3580}
3581
3582// ======================== SmolLm3 loader
3583
/// Loader that builds SmolLm3 text models; plug it into a [`NormalLoader`].
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct SmolLm3Loader;
3588
3589impl NormalModelLoader for SmolLm3Loader {
3590    fn load(
3591        &self,
3592        config: &str,
3593        vb: ShardedVarBuilder,
3594        normal_loading_metadata: NormalLoadingMetadata,
3595        attention_mechanism: AttentionImplementation,
3596    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3597        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3598
3599        Ok(Box::new(models::smollm3::SmolLm3::new(
3600            &cfg,
3601            vb,
3602            self.is_gptx(config)?,
3603            normal_loading_metadata,
3604            attention_mechanism,
3605        )?))
3606    }
3607    fn load_xlora(
3608        &self,
3609        _config: &str,
3610        _vb: ShardedVarBuilder,
3611        _lora_config: &[((String, String), LoraConfig)],
3612        _xlora_config: Option<XLoraConfig>,
3613        _xlora_ordering: Ordering,
3614        _normal_loading_metadata: NormalLoadingMetadata,
3615        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3616    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3617        todo!()
3618    }
3619    fn is_gptx(&self, _: &str) -> Result<bool> {
3620        Ok(true)
3621    }
3622    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3623        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3624        Ok(Box::new(cfg))
3625    }
3626}
3627
3628impl IsqModelLoader for SmolLm3Loader {
3629    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3630        Ok(vec![
3631            Regex::new(r"lm_head\.(weight|bias)$")?,
3632            // Attention
3633            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3634            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3635            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3636            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3637            // MLP
3638            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3639            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3640            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3641        ])
3642    }
3643    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3644        self.isq_layer_regexes(config)
3645    }
3646}
3647
3648impl DeviceMappedModelLoader for SmolLm3Loader {
3649    fn mapped_max_act_size_elems(
3650        &self,
3651        config: &str,
3652        params: &AutoDeviceMapParams,
3653        prompt_chunksize: usize,
3654    ) -> Result<usize> {
3655        let AutoDeviceMapParams::Text {
3656            max_seq_len: _,
3657            max_batch_size,
3658        } = params
3659        else {
3660            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3661        };
3662
3663        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3664
3665        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3666    }
3667    fn non_mapped_max_act_size_elems(
3668        &self,
3669        _config: &str,
3670        _params: &AutoDeviceMapParams,
3671    ) -> Result<usize> {
3672        Ok(0)
3673    }
3674
3675    fn non_mapped_size_in_bytes(
3676        &self,
3677        config: &str,
3678        dtype: DType,
3679        weight_pack_factor: usize,
3680        _matformer_config: Option<&MatformerSliceConfig>,
3681    ) -> Result<usize> {
3682        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3683
3684        let elems = {
3685            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3686            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3687            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3688                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3689            } else {
3690                0
3691            };
3692            let norm = cfg.hidden_size;
3693            embed_tokens + lm_head + norm
3694        };
3695        Ok(elems * dtype.size_in_bytes())
3696    }
3697
3698    fn layer_sizes_in_bytes(
3699        &self,
3700        config: &str,
3701        dtype: DType,
3702        weight_pack_factor: usize,
3703        _matformer_config: Option<&MatformerSliceConfig>,
3704    ) -> Result<Vec<usize>> {
3705        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3706
3707        let per_layer_elems = {
3708            let input_layernorm = cfg.hidden_size;
3709            let post_attention_layernorm = cfg.hidden_size;
3710
3711            let size_in = cfg.hidden_size;
3712            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3713            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3714            let q_proj = size_in * size_q / weight_pack_factor;
3715            let k_proj = size_in * size_kv / weight_pack_factor;
3716            let v_proj = size_in * size_kv / weight_pack_factor;
3717            let o_proj = size_q * size_in / weight_pack_factor;
3718
3719            let h_size = cfg.hidden_size;
3720            let i_size = cfg.intermediate_size;
3721            let gate_proj = h_size * i_size / weight_pack_factor;
3722            let up_proj = h_size * i_size / weight_pack_factor;
3723            let down_proj = i_size * h_size / weight_pack_factor;
3724
3725            input_layernorm
3726                + post_attention_layernorm
3727                + q_proj
3728                + k_proj
3729                + v_proj
3730                + o_proj
3731                + gate_proj
3732                + up_proj
3733                + down_proj
3734        };
3735        Ok(vec![
3736            per_layer_elems * dtype.size_in_bytes();
3737            cfg.num_hidden_layers
3738        ])
3739    }
3740
3741    fn num_layers(&self, config: &str) -> Result<usize> {
3742        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3743
3744        Ok(cfg.num_hidden_layers)
3745    }
3746    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3747        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3748
3749        let cfg = ModelConfigMetadata {
3750            max_seq_len: cfg.max_position_embeddings,
3751            num_layers: cfg.num_hidden_layers,
3752            hidden_size: cfg.hidden_size,
3753            num_kv_heads: cfg.num_key_value_heads,
3754            num_attn_heads: cfg.num_attention_heads,
3755            sliding_window: None,
3756            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3757            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3758        };
3759
3760        Ok(Box::new(cfg))
3761    }
3762}