mistralrs_core/pipeline/loaders/
normal_loaders.rs

use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::{Activation, Llama3RopeConfig, PhiRopeScalingConfig},
    lora::{LoraConfig, Ordering},
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{
        isq::IsqModelLoader,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    serde_default_fn,
    utils::{log::once_log_info, varbuilder_utils::DeviceForLoadTensor},
    xlora_models::NonGranularState,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};

use indicatif::MultiProgress;
use mistralrs_quant::{QuantizedConfig, ShardedVarBuilder};
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use crate::{
    models,
    xlora_models::{self, XLoraConfig},
};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn is_xlora(&self) -> bool;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
}

/// Metadata for loading a model with ISQ or device mapping.
pub struct NormalLoadingMetadata {
    // Device mapping metadata which can be used to construct a concrete device mapper
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Flag to check if loading in ISQ
    pub loading_isq: bool,
    // Device mapping target device (the one that is not the cpu)
    pub real_device: Device,
    // MultiProgress support for parallelized loading
    pub multi_progress: Arc<MultiProgress>,
}

pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    #[allow(clippy::too_many_arguments)]
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}
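// Illustration of the default `get_device_for_tensor` mapping above (the tensor
// names follow the usual Hugging Face safetensors layout and are only examples):
// "model.layers.17.self_attn.q_proj.weight" is captured by `\.layers\.(\d+)\.`
// and becomes `DeviceForLoadTensor::Idx(17)` (clamped to the layer count), while
// a name without a layer index, such as "model.embed_tokens.weight", falls back
// to `DeviceForLoadTensor::Base`. When loading for ISQ, every tensor goes to the
// base device.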

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the normal model as.
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    #[serde(rename = "qwen3")]
    Qwen3,
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
}

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448

impl NormalLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "MistralForCausalLM" => Ok(Self::Mistral),
            "MixtralForCausalLM" => Ok(Self::Mixtral),
            "GemmaForCausalLM" => Ok(Self::Gemma),
            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
            "PhiForCausalLM" => Ok(Self::Phi2),
            "Phi3ForCausalLM" => Ok(Self::Phi3),
            "LlamaForCausalLM" => Ok(Self::Llama),
            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `-CausalLM` model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for NormalLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "mistral" => Ok(Self::Mistral),
            "gemma" => Ok(Self::Gemma),
            "mixtral" => Ok(Self::Mixtral),
            "llama" => Ok(Self::Llama),
            "phi2" => Ok(Self::Phi2),
            "phi3" => Ok(Self::Phi3),
            "qwen2" => Ok(Self::Qwen2),
            "gemma2" => Ok(Self::Gemma2),
            "starcoder2" => Ok(Self::Starcoder2),
            "phi3.5moe" => Ok(Self::Phi3_5MoE),
            "deepseekv2" => Ok(Self::DeepSeekV2),
            "deepseekv3" => Ok(Self::DeepSeekV3),
            "qwen3" => Ok(Self::Qwen3),
            "qwen3moe" => Ok(Self::Qwen3Moe),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `qwen3moe`.")),
        }
    }
}

impl Display for NormalLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gemma => write!(f, "gemma"),
            Self::Gemma2 => write!(f, "gemma2"),
            Self::Llama => write!(f, "llama"),
            Self::Mistral => write!(f, "mistral"),
            Self::Mixtral => write!(f, "mixtral"),
            Self::Phi2 => write!(f, "phi2"),
            Self::Phi3 => write!(f, "phi3"),
            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
            Self::Qwen2 => write!(f, "qwen2"),
            Self::Starcoder2 => write!(f, "starcoder2"),
            Self::DeepSeekV2 => write!(f, "deepseekv2"),
            Self::DeepSeekV3 => write!(f, "deepseekv3"),
            Self::Qwen3 => write!(f, "qwen3"),
            Self::Qwen3Moe => write!(f, "qwen3moe"),
        }
    }
}
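// A minimal round-trip sketch (not part of the upstream file; the module and
// test names are illustrative) tying the conversions above together:
// `from_causal_lm_name` maps a Transformers class name to a loader type, while
// `FromStr` and `Display` round-trip the short architecture strings.
#[cfg(test)]
mod normal_loader_type_round_trip {
    use super::NormalLoaderType;
    use std::str::FromStr;

    #[test]
    fn class_name_and_short_name_agree() {
        // The HF class name and the short CLI-style name resolve to the same variant.
        let from_class = NormalLoaderType::from_causal_lm_name("LlamaForCausalLM").unwrap();
        let from_short = NormalLoaderType::from_str("llama").unwrap();
        assert_eq!(from_class, from_short);
        // `Display` prints the short name, so parsing its output round-trips.
        assert_eq!(
            NormalLoaderType::from_str(&from_short.to_string()).unwrap(),
            from_short
        );
    }
}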

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
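// For example, `bias_if!(cfg.attention_bias, size_q)` adds `size_q` elements for
// the bias vector when `attention_bias` is set and contributes 0 otherwise; the
// per-layer size estimates below use it wherever a projection may carry a bias.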

/// Load a model based on the Hugging Face Transformers `-CausalLM` model class
pub struct AutoLoader;

#[derive(Deserialize)]
struct AutoLoaderConfig {
    architectures: Vec<String>,
}

impl AutoLoader {
    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
        let auto_cfg: AutoLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one name in the `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = NormalLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
        }
    }
}

impl NormalModelLoader for AutoLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load(
            config,
            use_flash_attn,
            vb,
            normal_loading_metadata,
            attention_mechanism,
        )
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load_xlora(
            config,
            use_flash_attn,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            normal_loading_metadata,
            preload_adapters,
        )
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config, use_flash_attn)
    }
    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.supports_paged_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for AutoLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}
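// Sketch (not from upstream; the test module name and config literal are
// illustrative) of the minimal JSON the `AutoLoader` inspects when picking a
// concrete loader: only the `architectures` field matters at this stage, and it
// must contain exactly one recognized `-CausalLM` class name.
#[cfg(test)]
mod auto_loader_architecture_selection {
    use super::{AutoLoaderConfig, NormalLoaderType};

    #[test]
    fn architectures_field_selects_loader_type() {
        let cfg: AutoLoaderConfig =
            serde_json::from_str(r#"{ "architectures": ["MistralForCausalLM"] }"#).unwrap();
        assert_eq!(cfg.architectures.len(), 1);
        // `AutoLoader::get_loader` would hand this name to `from_causal_lm_name`
        // and construct a `MistralLoader` for it.
        assert_eq!(
            NormalLoaderType::from_causal_lm_name(&cfg.architectures[0]).unwrap(),
            NormalLoaderType::Mistral
        );
    }
}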

serde_default_fn!(bool, word_emb_default, false);

// ======================== Mistral loader

#[derive(Deserialize, Debug)]
struct MistralBasicConfig {
    vocab_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    hidden_act: Activation,
    max_position_embeddings: usize,
    rms_norm_eps: f64,
    rope_theta: f64,
    sliding_window: Option<usize>,
    head_dim: Option<usize>,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl MistralBasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::mistral::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::mistral::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            max_position_embeddings: basic_config.max_position_embeddings,
            rms_norm_eps: basic_config.rms_norm_eps,
            rope_theta: basic_config.rope_theta,
            sliding_window: basic_config.sliding_window,
            use_flash_attn,
            head_dim: basic_config.head_dim,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

pub struct MistralLoader;

impl NormalModelLoader for MistralLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::mistral::Model::new(
            &MistralBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraMistral::new(
            &MistralBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(MistralBasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for MistralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}
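// Illustrative tensor names matched by the regexes above (following the usual
// Hugging Face safetensors naming): "lm_head.weight",
// "model.layers.0.self_attn.q_proj.weight", "model.layers.31.mlp.down_proj.weight".
// Embedding and norm weights are not matched by any of these patterns.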

impl DeviceMappedModelLoader for MistralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = MistralBasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = MistralBasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = MistralBasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = MistralBasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = MistralBasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}
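// Worked example of the per-layer estimate above, using the commonly published
// Mistral-7B hyperparameters purely for illustration (hidden_size = 4096,
// intermediate_size = 14336, 32 attention heads, 8 KV heads, weight_pack_factor = 1):
//   q_proj = o_proj = 4096 * 4096       = 16_777_216 each
//   k_proj = v_proj = 4096 * (128 * 8)  =  4_194_304 each
//   gate/up/down    = 4096 * 14336      = 58_720_256 each
//   two RMS norms                       =      8_192
// for a total of 218_112_000 elements per layer, i.e. roughly 0.44 GB per layer
// at a 2-byte dtype, repeated once per hidden layer by `layer_sizes_in_bytes`.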

// ======================== Gemma loader

fn default_max_position_embeddings() -> usize {
    4096
}

#[derive(Deserialize)]
struct GemmaBasicConfig {
    attention_bias: bool,
    head_dim: usize,
    // The code gemma configs include both hidden_act and hidden_activation.
    hidden_act: Option<Activation>,
    hidden_activation: Option<Activation>,
    hidden_size: usize,
    intermediate_size: usize,
    num_attention_heads: usize,
    num_hidden_layers: usize,
    num_key_value_heads: usize,
    rms_norm_eps: f64,
    rope_theta: f64,
    vocab_size: usize,

    #[serde(default = "default_max_position_embeddings")]
    max_position_embeddings: usize,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl GemmaBasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::gemma::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::gemma::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            hidden_activation: basic_config.hidden_activation,
            max_position_embeddings: basic_config.max_position_embeddings,
            rms_norm_eps: basic_config.rms_norm_eps,
            rope_theta: basic_config.rope_theta,
            attention_bias: basic_config.attention_bias,
            head_dim: basic_config.head_dim,
            use_flash_attn,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

/// [`NormalLoader`] for a Gemma model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct GemmaLoader;

impl NormalModelLoader for GemmaLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::gemma::Model::new(
            &GemmaBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraGemma::new(
            &GemmaBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(GemmaBasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for GemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for GemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = GemmaBasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = GemmaBasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = GemmaBasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = GemmaBasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = GemmaBasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }
}
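// Note on the Gemma estimate above: unlike the Mistral loader, the q/k/v sizes
// come from the explicit `head_dim` field rather than `hidden_size / num_attention_heads`,
// and `bias_if!` adds the bias vectors when `attention_bias` is set. When
// `tie_word_embeddings` is true (common for Gemma checkpoints), `lm_head`
// contributes nothing to the non-mapped size.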

// ======================== Llama loader

#[derive(Deserialize)]
struct LlamaBasicConfig {
    hidden_act: Activation,
    hidden_size: usize,
    intermediate_size: usize,
    vocab_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: Option<usize>,
    rms_norm_eps: f64,
    #[serde(default = "default_rope")]
    rope_theta: f32,
    max_position_embeddings: usize,
    rope_scaling: Option<Llama3RopeConfig>,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

fn default_rope() -> f32 {
    10_000.0
}

impl LlamaBasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::llama::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::llama::Config {
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            vocab_size: basic_config.vocab_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config
                .num_key_value_heads
                .unwrap_or(basic_config.num_attention_heads),
            rms_norm_eps: basic_config.rms_norm_eps,
            rope_theta: basic_config.rope_theta,
            use_flash_attn,
            max_position_embeddings: basic_config.max_position_embeddings,
            rope_scaling: basic_config.rope_scaling,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
            hidden_act: basic_config.hidden_act,
        })
    }
}

/// [`NormalLoader`] for a Llama model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct LlamaLoader;

impl NormalModelLoader for LlamaLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::llama::Llama::new(
            &LlamaBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraLlama::new(
            &LlamaBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(LlamaBasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for LlamaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = LlamaBasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = LlamaBasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = LlamaBasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = LlamaBasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = LlamaBasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Mixtral loader

#[derive(Deserialize)]
struct MixtralBasicConfig {
    vocab_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    hidden_act: Activation,
    max_position_embeddings: usize,
    rms_norm_eps: f64,
    rope_theta: f64,
    sliding_window: Option<usize>,
    num_experts_per_tok: usize,
    num_local_experts: usize,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl MixtralBasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::mixtral::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::mixtral::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            max_position_embeddings: basic_config.max_position_embeddings,
            rms_norm_eps: basic_config.rms_norm_eps,
            rope_theta: basic_config.rope_theta,
            sliding_window: basic_config.sliding_window,
            use_flash_attn,
            num_experts_per_tok: basic_config.num_experts_per_tok,
            num_local_experts: basic_config.num_local_experts,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

pub struct MixtralLoader;

impl NormalModelLoader for MixtralLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::mixtral::Model::new(
            &MixtralBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraMixtral::new(
            &MixtralBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(MixtralBasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for MixtralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // Experts
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for MixtralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = MixtralBasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = MixtralBasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = MixtralBasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let moe_block = {
                let gate = cfg.hidden_size * cfg.num_local_experts;
                // Assume the expert weights are quantized with the same weight pack factor.
                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                gate + cfg.num_local_experts * w1
                    + cfg.num_local_experts * w2
                    + cfg.num_local_experts * w3
            };

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + moe_block
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = MixtralBasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = MixtralBasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}
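// Worked example of the MoE block above, using the commonly published
// Mixtral-8x7B hyperparameters purely for illustration (hidden_size = 4096,
// intermediate_size = 14336, num_local_experts = 8, weight_pack_factor = 1):
// the router gate is 4096 * 8 = 32_768 elements, while each expert carries
// w1 + w2 + w3 = 3 * 4096 * 14336 = 176_160_768 elements, so the eight experts
// dominate the layer at about 1.41e9 elements versus ~4.2e7 for attention.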

// ======================== Phi2 loader

#[derive(Deserialize)]
struct Phi2BasicConfig {
    vocab_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: Option<usize>,
    hidden_act: Activation,
    max_position_embeddings: usize,
    layer_norm_eps: f64,
    rope_theta: f32,
    partial_rotary_factor: f64,
    qk_layernorm: bool,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl Phi2BasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi2::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::phi2::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            max_position_embeddings: basic_config.max_position_embeddings,
            rope_theta: basic_config.rope_theta,
            layer_norm_eps: basic_config.layer_norm_eps,
            partial_rotary_factor: basic_config.partial_rotary_factor,
            qk_layernorm: basic_config.qk_layernorm,
            use_flash_attn,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

/// [`NormalLoader`] for a Phi 2 model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Phi2Loader;

impl NormalModelLoader for Phi2Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::phi2::Model::new(
            &Phi2BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraPhi2::new(
            &Phi2BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(Phi2BasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for Phi2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Phi2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = Phi2BasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = Phi2BasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = Phi2BasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor + size_in;
            let (q_norm, k_norm) = if cfg.qk_layernorm {
                (cfg.head_dim(), cfg.head_dim())
            } else {
                (0, 0)
            };

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let fc1 = h_size * i_size / weight_pack_factor;
            let fc2 = h_size * i_size / weight_pack_factor;

            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = Phi2BasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = Phi2BasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}
1538
1539// ======================== Phi3 loader
1540
1541#[derive(Deserialize)]
1542struct Phi3BasicConfig {
1543    vocab_size: usize,
1544    hidden_act: Activation,
1545    hidden_size: usize,
1546    intermediate_size: usize,
1547    num_hidden_layers: usize,
1548    num_attention_heads: usize,
1549    num_key_value_heads: usize,
1550    rms_norm_eps: f64,
1551    rope_theta: f64,
1552    bos_token_id: Option<u32>,
1553    eos_token_id: Option<u32>,
1554    rope_scaling: Option<PhiRopeScalingConfig>,
1555    max_position_embeddings: usize,
1556    original_max_position_embeddings: usize,
1557    sliding_window: Option<usize>,
1558    quantization_config: Option<QuantizedConfig>,
1559    #[serde(default = "word_emb_default")]
1560    tie_word_embeddings: bool,
1561    partial_rotary_factor: Option<f64>,
1562}
1563
1564impl Phi3BasicConfig {
1565    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi3::Config> {
1566        let basic_config: Self = serde_json::from_str(slice)?;
1567        Ok(models::phi3::Config {
1568            vocab_size: basic_config.vocab_size,
1569            hidden_size: basic_config.hidden_size,
1570            intermediate_size: basic_config.intermediate_size,
1571            num_hidden_layers: basic_config.num_hidden_layers,
1572            num_attention_heads: basic_config.num_attention_heads,
1573            num_key_value_heads: basic_config.num_key_value_heads,
1574            hidden_act: basic_config.hidden_act,
1575            max_position_embeddings: basic_config.max_position_embeddings,
1576            rope_theta: basic_config.rope_theta,
1577            rms_norm_eps: basic_config.rms_norm_eps,
1578            eos_token_id: basic_config.eos_token_id,
1579            bos_token_id: basic_config.bos_token_id,
1580            rope_scaling: basic_config.rope_scaling,
1581            original_max_position_embeddings: basic_config.original_max_position_embeddings,
1582            use_flash_attn,
1583            sliding_window: basic_config.sliding_window,
1584            quantization_config: basic_config.quantization_config,
1585            tie_word_embeddings: basic_config.tie_word_embeddings,
1586            partial_rotary_factor: basic_config.partial_rotary_factor,
1587        })
1588    }
1589}
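// For illustration only: a minimal, hypothetical config JSON that
// `Phi3BasicConfig::deserialize` should accept. The values are made up (roughly
// Phi-3-mini shaped) and every Option-typed field is simply omitted, leaving it `None`:
//
// {
//   "vocab_size": 32064,
//   "hidden_act": "silu",
//   "hidden_size": 3072,
//   "intermediate_size": 8192,
//   "num_hidden_layers": 32,
//   "num_attention_heads": 32,
//   "num_key_value_heads": 32,
//   "rms_norm_eps": 1e-5,
//   "rope_theta": 10000.0,
//   "max_position_embeddings": 4096,
//   "original_max_position_embeddings": 4096
// }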
1590
1591/// [`NormalLoader`] for a Phi 3 model.
1592///
1593/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1594pub struct Phi3Loader;
1595
1596impl NormalModelLoader for Phi3Loader {
1597    fn load(
1598        &self,
1599        config: &str,
1600        use_flash_attn: bool,
1601        vb: ShardedVarBuilder,
1602        normal_loading_metadata: NormalLoadingMetadata,
1603        attention_mechanism: AttentionImplementation,
1604    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1605        Ok(Box::new(models::phi3::Model::new(
1606            &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
1607            vb,
1608            self.is_gptx(config)?,
1609            normal_loading_metadata,
1610            attention_mechanism,
1611        )?))
1612    }
1613    fn load_xlora(
1614        &self,
1615        config: &str,
1616        use_flash_attn: bool,
1617        vb: ShardedVarBuilder,
1618        lora_config: &[((String, String), LoraConfig)],
1619        xlora_config: Option<XLoraConfig>,
1620        xlora_ordering: Ordering,
1621        normal_loading_metadata: NormalLoadingMetadata,
1622        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1623    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1624        Ok(Box::new(xlora_models::XLoraPhi3::new(
1625            &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
1626            vb,
1627            lora_config,
1628            xlora_config,
1629            xlora_ordering,
1630            self.is_gptx(config)?,
1631            normal_loading_metadata,
1632            preload_adapters,
1633        )?))
1634    }
1635    fn is_gptx(&self, _: &str) -> Result<bool> {
1636        Ok(true)
1637    }
1638    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1639        Ok(Box::new(Phi3BasicConfig::deserialize(
1640            config,
1641            use_flash_attn,
1642        )?))
1643    }
1644}
1645
1646impl IsqModelLoader for Phi3Loader {
1647    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1648        Ok(vec![
1649            Regex::new(r"lm_head\.(weight|bias)$")?,
1650            // Attention
1651            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1652            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1653            // MLP
1654            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1655            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1656            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1657        ])
1658    }
1659}
1660
1661impl DeviceMappedModelLoader for Phi3Loader {
1662    fn mapped_max_act_size_elems(
1663        &self,
1664        config: &str,
1665        params: &AutoDeviceMapParams,
1666        prompt_chunksize: usize,
1667    ) -> Result<usize> {
1668        let AutoDeviceMapParams::Text {
1669            max_seq_len: _,
1670            max_batch_size,
1671        } = params
1672        else {
1673            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1674        };
1675
1676        let cfg = Phi3BasicConfig::deserialize(config, false)?;
1677
1678        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1679    }
1680    fn non_mapped_max_act_size_elems(
1681        &self,
1682        _config: &str,
1683        _params: &AutoDeviceMapParams,
1684    ) -> Result<usize> {
1685        Ok(0)
1686    }
1687
1688    fn non_mapped_size_in_bytes(
1689        &self,
1690        config: &str,
1691        dtype: DType,
1692        weight_pack_factor: usize,
1693    ) -> Result<usize> {
1694        let cfg = Phi3BasicConfig::deserialize(config, false)?;
1695        let elems = {
1696            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1697            let lm_head = if !cfg.tie_word_embeddings {
1698                cfg.hidden_size * cfg.vocab_size
1699            } else {
1700                0
1701            };
1702            let norm = cfg.hidden_size;
1703            embed_tokens + lm_head + norm
1704        };
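        // For example, with a hypothetical 3072-hidden / 32064-vocab config, untied
        // embeddings, F16, and no packing: (2 * 3072 * 32064 + 3072) * 2 bytes ≈ 394 MB.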
1705        Ok(elems * dtype.size_in_bytes())
1706    }
1707
1708    fn layer_sizes_in_bytes(
1709        &self,
1710        config: &str,
1711        dtype: DType,
1712        weight_pack_factor: usize,
1713    ) -> Result<Vec<usize>> {
1714        let cfg = Phi3BasicConfig::deserialize(config, false)?;
1715        let per_layer_elems = {
1716            let input_layernorm = cfg.hidden_size;
1717            let post_attention_layernorm = cfg.hidden_size;
1718
1719            let size_in = cfg.hidden_size;
1720            let head_dim = cfg.head_dim();
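            // Fused qkv projection output width: q is heads × head_dim, while k and v
            // are kv_heads × head_dim each.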
1721            let op_size = cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1722            let qkv_proj = size_in * op_size / weight_pack_factor;
1723            let o_proj =
1724                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1725
1726            let h_size = cfg.hidden_size;
1727            let i_size = cfg.intermediate_size;
1728            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1729            let down_proj = h_size * i_size / weight_pack_factor;
1730
1731            input_layernorm
1732                + post_attention_layernorm
1733                + qkv_proj
1734                + o_proj
1735                + gate_up_proj
1736                + down_proj
1737        };
1738        Ok(vec![
1739            per_layer_elems * dtype.size_in_bytes();
1740            cfg.num_hidden_layers
1741        ])
1742    }
1743
1744    fn num_layers(&self, config: &str) -> Result<usize> {
1745        let cfg = Phi3BasicConfig::deserialize(config, false)?;
1746        Ok(cfg.num_hidden_layers)
1747    }
1748
1749    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1750        let cfg = Phi3BasicConfig::deserialize(config, false)?;
1751
1752        let cfg = ModelConfigMetadata {
1753            max_seq_len: cfg.max_position_embeddings,
1754            num_layers: cfg.num_hidden_layers,
1755            hidden_size: cfg.hidden_size,
1756            num_kv_heads: cfg.num_key_value_heads,
1757            num_attn_heads: cfg.num_attention_heads,
1758            sliding_window: cfg.sliding_window,
1759            k_head_dim: cfg.head_dim(),
1760            v_head_dim: cfg.head_dim(),
1761        };
1762
1763        Ok(Box::new(cfg))
1764    }
1765}
1766
1767// ======================== Qwen2 loader
1768
1769#[derive(Deserialize)]
1770struct Qwen2BasicConfig {
1771    vocab_size: usize,
1772    hidden_size: usize,
1773    intermediate_size: usize,
1774    num_hidden_layers: usize,
1775    num_attention_heads: usize,
1776    num_key_value_heads: usize,
1777    max_position_embeddings: usize,
1778    sliding_window: Option<usize>,
1779    rope_theta: f64,
1780    rms_norm_eps: f64,
1781    hidden_act: Activation,
1782    quantization_config: Option<QuantizedConfig>,
1783    tie_word_embeddings: bool,
1784}
1785
1786impl Qwen2BasicConfig {
1787    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::qwen2::Config> {
1788        let basic_config: Self = serde_json::from_str(slice)?;
1789        Ok(models::qwen2::Config {
1790            vocab_size: basic_config.vocab_size,
1791            hidden_size: basic_config.hidden_size,
1792            intermediate_size: basic_config.intermediate_size,
1793            num_hidden_layers: basic_config.num_hidden_layers,
1794            num_attention_heads: basic_config.num_attention_heads,
1795            num_key_value_heads: basic_config.num_key_value_heads,
1796            hidden_act: basic_config.hidden_act,
1797            max_position_embeddings: basic_config.max_position_embeddings,
1798            rope_theta: basic_config.rope_theta,
1799            rms_norm_eps: basic_config.rms_norm_eps,
1800            sliding_window: basic_config.sliding_window,
1801            use_flash_attn,
1802            quantization_config: basic_config.quantization_config,
1803            tie_word_embeddings: basic_config.tie_word_embeddings,
1804        })
1805    }
1806}
1807
1808/// [`NormalLoader`] for a Qwen 2 model.
1809///
1810/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1811pub struct Qwen2Loader;
1812
1813impl NormalModelLoader for Qwen2Loader {
1814    fn load(
1815        &self,
1816        config: &str,
1817        use_flash_attn: bool,
1818        vb: ShardedVarBuilder,
1819        normal_loading_metadata: NormalLoadingMetadata,
1820        attention_mechanism: AttentionImplementation,
1821    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1822        Ok(Box::new(models::qwen2::Model::new(
1823            &Qwen2BasicConfig::deserialize(config, use_flash_attn)?,
1824            vb,
1825            self.is_gptx(config)?,
1826            normal_loading_metadata,
1827            attention_mechanism,
1828        )?))
1829    }
1830    fn load_xlora(
1831        &self,
1832        _config: &str,
1833        _use_flash_attn: bool,
1834        _vb: ShardedVarBuilder,
1835        _lora_config: &[((String, String), LoraConfig)],
1836        _xlora_config: Option<XLoraConfig>,
1837        _xlora_ordering: Ordering,
1838        _normal_loading_metadata: NormalLoadingMetadata,
1839        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1840    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1841        todo!()
1842    }
1843    fn is_gptx(&self, _: &str) -> Result<bool> {
1844        Ok(true)
1845    }
1846    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1847        Ok(Box::new(Qwen2BasicConfig::deserialize(
1848            config,
1849            use_flash_attn,
1850        )?))
1851    }
1852}
1853
1854impl IsqModelLoader for Qwen2Loader {
1855    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1856        Ok(vec![
1857            Regex::new(r"lm_head\.(weight|bias)$")?,
1858            // Attention
1859            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1860            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1861            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1862            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1863            // MLP
1864            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1865            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1866            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1867        ])
1868    }
1869}
1870
1871impl DeviceMappedModelLoader for Qwen2Loader {
1872    fn mapped_max_act_size_elems(
1873        &self,
1874        config: &str,
1875        params: &AutoDeviceMapParams,
1876        prompt_chunksize: usize,
1877    ) -> Result<usize> {
1878        let AutoDeviceMapParams::Text {
1879            max_seq_len: _,
1880            max_batch_size,
1881        } = params
1882        else {
1883            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1884        };
1885
1886        let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1887
1888        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1889    }
1890    fn non_mapped_max_act_size_elems(
1891        &self,
1892        _config: &str,
1893        _params: &AutoDeviceMapParams,
1894    ) -> Result<usize> {
1895        Ok(0)
1896    }
1897
1898    fn non_mapped_size_in_bytes(
1899        &self,
1900        config: &str,
1901        dtype: DType,
1902        weight_pack_factor: usize,
1903    ) -> Result<usize> {
1904        let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1905        let elems = {
1906            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1907            let lm_head = if !cfg.tie_word_embeddings {
1908                cfg.hidden_size * cfg.vocab_size
1909            } else {
1910                0
1911            };
1912            let norm = cfg.hidden_size;
1913            embed_tokens + lm_head + norm
1914        };
1915        Ok(elems * dtype.size_in_bytes())
1916    }
1917
1918    fn layer_sizes_in_bytes(
1919        &self,
1920        config: &str,
1921        dtype: DType,
1922        weight_pack_factor: usize,
1923    ) -> Result<Vec<usize>> {
1924        let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1925        let per_layer_elems = {
1926            let input_layernorm = cfg.hidden_size;
1927            let post_attention_layernorm = cfg.hidden_size;
1928
1929            let size_in = cfg.hidden_size;
1930            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1931            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1932            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1933            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1934            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1935            let o_proj = size_q * size_in / weight_pack_factor;
1936
1937            let h_size = cfg.hidden_size;
1938            let i_size = cfg.intermediate_size;
1939            let gate_proj = h_size * i_size / weight_pack_factor;
1940            let up_proj = h_size * i_size / weight_pack_factor;
1941            let down_proj = i_size * h_size / weight_pack_factor;
1942
1943            input_layernorm
1944                + post_attention_layernorm
1945                + q_proj
1946                + k_proj
1947                + v_proj
1948                + o_proj
1949                + gate_proj
1950                + up_proj
1951                + down_proj
1952        };
1953        Ok(vec![
1954            per_layer_elems * dtype.size_in_bytes();
1955            cfg.num_hidden_layers
1956        ])
1957    }
1958
1959    fn num_layers(&self, config: &str) -> Result<usize> {
1960        let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1961        Ok(cfg.num_hidden_layers)
1962    }
1963
1964    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1965        let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1966
1967        let cfg = ModelConfigMetadata {
1968            max_seq_len: cfg.max_position_embeddings,
1969            num_layers: cfg.num_hidden_layers,
1970            hidden_size: cfg.hidden_size,
1971            num_kv_heads: cfg.num_key_value_heads,
1972            num_attn_heads: cfg.num_attention_heads,
1973            sliding_window: cfg.sliding_window,
1974            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1975            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1976        };
1977
1978        Ok(Box::new(cfg))
1979    }
1980}
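// A minimal sketch (not part of the upstream file) of how the size-estimation helpers
// above could be exercised in a unit test. The config JSON is hypothetical: it fills
// only the fields `Qwen2BasicConfig` requires with illustrative values, and the "silu"
// spelling of `hidden_act` is assumed to match the `Activation` deserializer.
#[cfg(test)]
mod qwen2_size_estimate_sketch {
    use super::*;

    #[test]
    fn per_layer_estimates_are_uniform() {
        let config = r#"{
            "vocab_size": 151936,
            "hidden_size": 1024,
            "intermediate_size": 2816,
            "num_hidden_layers": 24,
            "num_attention_heads": 16,
            "num_key_value_heads": 16,
            "max_position_embeddings": 32768,
            "rope_theta": 1000000.0,
            "rms_norm_eps": 1e-6,
            "hidden_act": "silu",
            "tie_word_embeddings": true
        }"#;
        // Every decoder layer is assumed identical, so each per-layer estimate matches.
        let sizes = Qwen2Loader
            .layer_sizes_in_bytes(config, DType::F16, 1)
            .unwrap();
        assert_eq!(sizes.len(), 24);
        assert!(sizes.windows(2).all(|w| w[0] == w[1]));
    }
}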
1981
1982// ======================== Gemma2 loader
1983
1984#[derive(Deserialize)]
1985struct Gemma2BasicConfig {
1986    attention_bias: bool,
1987    head_dim: usize,
1988    // The code gemma configs include both hidden_act and hidden_activation.
1989    hidden_act: Option<Activation>,
1990    hidden_activation: Option<Activation>,
1991    hidden_size: usize,
1992    intermediate_size: usize,
1993    num_attention_heads: usize,
1994    num_hidden_layers: usize,
1995    num_key_value_heads: usize,
1996    rms_norm_eps: f64,
1997    rope_theta: f64,
1998    vocab_size: usize,
1999    sliding_window: usize,
2000    attn_logit_softcapping: Option<f64>,
2001    final_logit_softcapping: Option<f64>,
2002    query_pre_attn_scalar: usize,
2003
2004    #[serde(default = "default_max_position_embeddings")]
2005    max_position_embeddings: usize,
2006    quantization_config: Option<QuantizedConfig>,
2007    #[serde(default = "word_emb_default")]
2008    tie_word_embeddings: bool,
2009}
2010
2011impl Gemma2BasicConfig {
2012    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::gemma2::Config> {
2013        let basic_config: Self = serde_json::from_str(slice)?;
2014        Ok(models::gemma2::Config {
2015            vocab_size: basic_config.vocab_size,
2016            hidden_size: basic_config.hidden_size,
2017            intermediate_size: basic_config.intermediate_size,
2018            num_hidden_layers: basic_config.num_hidden_layers,
2019            num_attention_heads: basic_config.num_attention_heads,
2020            num_key_value_heads: basic_config.num_key_value_heads,
2021            hidden_act: basic_config.hidden_act,
2022            hidden_activation: basic_config.hidden_activation,
2023            max_position_embeddings: basic_config.max_position_embeddings,
2024            rms_norm_eps: basic_config.rms_norm_eps,
2025            rope_theta: basic_config.rope_theta,
2026            attention_bias: basic_config.attention_bias,
2027            head_dim: basic_config.head_dim,
2028            use_flash_attn,
2029            quantization_config: basic_config.quantization_config,
2030            sliding_window: basic_config.sliding_window,
2031            attn_logit_softcapping: basic_config.attn_logit_softcapping,
2032            final_logit_softcapping: basic_config.final_logit_softcapping,
2033            query_pre_attn_scalar: basic_config.query_pre_attn_scalar,
2034            tie_word_embeddings: basic_config.tie_word_embeddings,
2035        })
2036    }
2037}
2038
2039/// [`NormalLoader`] for a Gemma2 model.
2040///
2041/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2042pub struct Gemma2Loader;
2043
2044impl NormalModelLoader for Gemma2Loader {
2045    fn load(
2046        &self,
2047        config: &str,
2048        use_flash_attn: bool,
2049        vb: ShardedVarBuilder,
2050        normal_loading_metadata: NormalLoadingMetadata,
2051        attention_mechanism: AttentionImplementation,
2052    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2053        Ok(Box::new(models::gemma2::Model::new(
2054            &Gemma2BasicConfig::deserialize(config, use_flash_attn)?,
2055            vb,
2056            self.is_gptx(config)?,
2057            normal_loading_metadata,
2058            attention_mechanism,
2059        )?))
2060    }
2061    fn load_xlora(
2062        &self,
2063        config: &str,
2064        use_flash_attn: bool,
2065        vb: ShardedVarBuilder,
2066        lora_config: &[((String, String), LoraConfig)],
2067        xlora_config: Option<XLoraConfig>,
2068        xlora_ordering: Ordering,
2069        normal_loading_metadata: NormalLoadingMetadata,
2070        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2071    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2072        Ok(Box::new(xlora_models::XLoraGemma2::new(
2073            &Gemma2BasicConfig::deserialize(config, use_flash_attn)?,
2074            vb,
2075            lora_config,
2076            xlora_config,
2077            xlora_ordering,
2078            self.is_gptx(config)?,
2079            normal_loading_metadata,
2080            preload_adapters,
2081        )?))
2082    }
2083    fn is_gptx(&self, _: &str) -> Result<bool> {
2084        Ok(true)
2085    }
2086    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2087        // Already will warn about it
2088        Ok(Box::new(Gemma2BasicConfig::deserialize(
2089            config,
2090            use_flash_attn,
2091        )?))
2092    }
2093}
2094
2095impl IsqModelLoader for Gemma2Loader {
2096    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2097        Ok(vec![
2098            Regex::new(r"lm_head\.(weight|bias)$")?,
2099            // Attention
2100            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2101            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2102            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2103            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2104            // MLP
2105            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2106            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2107            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2108        ])
2109    }
2110}
2111
2112impl DeviceMappedModelLoader for Gemma2Loader {
2113    fn mapped_max_act_size_elems(
2114        &self,
2115        config: &str,
2116        params: &AutoDeviceMapParams,
2117        prompt_chunksize: usize,
2118    ) -> Result<usize> {
2119        let AutoDeviceMapParams::Text {
2120            max_seq_len: _,
2121            max_batch_size,
2122        } = params
2123        else {
2124            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2125        };
2126
2127        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2128
2129        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2130    }
2131    fn non_mapped_max_act_size_elems(
2132        &self,
2133        _config: &str,
2134        _params: &AutoDeviceMapParams,
2135    ) -> Result<usize> {
2136        Ok(0)
2137    }
2138
2139    fn non_mapped_size_in_bytes(
2140        &self,
2141        config: &str,
2142        dtype: DType,
2143        weight_pack_factor: usize,
2144    ) -> Result<usize> {
2145        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2146        let elems = {
2147            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2148            let lm_head = if !cfg.tie_word_embeddings {
2149                cfg.hidden_size * cfg.vocab_size
2150            } else {
2151                0
2152            };
2153            let norm = cfg.hidden_size;
2154            embed_tokens + lm_head + norm
2155        };
2156        Ok(elems * dtype.size_in_bytes())
2157    }
2158
2159    fn layer_sizes_in_bytes(
2160        &self,
2161        config: &str,
2162        dtype: DType,
2163        weight_pack_factor: usize,
2164    ) -> Result<Vec<usize>> {
2165        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2166        let per_layer_elems = {
2167            let input_layernorm = cfg.hidden_size;
2168            let post_attention_layernorm = cfg.hidden_size;
2169
2170            let size_in = cfg.hidden_size;
2171            let size_q = cfg.head_dim * cfg.num_attention_heads;
2172            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
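            // The bias_if! helper contributes the given bias element count only when its
            // first argument (here `cfg.attention_bias`) is true.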
2173            let q_proj =
2174                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2175            let k_proj =
2176                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2177            let v_proj =
2178                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2179            let o_proj =
2180                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2181
2182            let h_size = cfg.hidden_size;
2183            let i_size = cfg.intermediate_size;
2184            let gate_proj = h_size * i_size / weight_pack_factor;
2185            let up_proj = h_size * i_size / weight_pack_factor;
2186            let down_proj = i_size * h_size / weight_pack_factor;
2187
2188            input_layernorm
2189                + post_attention_layernorm
2190                + q_proj
2191                + k_proj
2192                + v_proj
2193                + o_proj
2194                + gate_proj
2195                + up_proj
2196                + down_proj
2197        };
2198        Ok(vec![
2199            per_layer_elems * dtype.size_in_bytes();
2200            cfg.num_hidden_layers
2201        ])
2202    }
2203
2204    fn num_layers(&self, config: &str) -> Result<usize> {
2205        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2206        Ok(cfg.num_hidden_layers)
2207    }
2208    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2209        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2210
2211        let cfg = ModelConfigMetadata {
2212            max_seq_len: cfg.max_position_embeddings,
2213            num_layers: cfg.num_hidden_layers,
2214            hidden_size: cfg.hidden_size,
2215            num_kv_heads: cfg.num_key_value_heads,
2216            num_attn_heads: cfg.num_attention_heads,
2217            sliding_window: None, // None to be more forgiving; some configs do not set it
2218            k_head_dim: cfg.head_dim,
2219            v_head_dim: cfg.head_dim,
2220        };
2221
2222        Ok(Box::new(cfg))
2223    }
2224}
2225
2226// ======================== Starcoder2 loader
2227
2228#[derive(Deserialize, Debug)]
2229struct Starcoder2BasicConfig {
2230    vocab_size: usize,
2231    hidden_size: usize,
2232    intermediate_size: usize,
2233    num_hidden_layers: usize,
2234    num_attention_heads: usize,
2235    num_key_value_heads: usize,
2236    hidden_act: Activation,
2237    max_position_embeddings: usize,
2238    norm_epsilon: f64,
2239    rope_theta: f64,
2240    use_bias: bool,
2241    sliding_window: Option<usize>,
2242    quantization_config: Option<QuantizedConfig>,
2243    #[serde(default = "word_emb_default")]
2244    tie_word_embeddings: bool,
2245}
2246
2247impl Starcoder2BasicConfig {
2248    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::starcoder2::Config> {
2249        let basic_config: Self = serde_json::from_str(slice)?;
2250        Ok(models::starcoder2::Config {
2251            vocab_size: basic_config.vocab_size,
2252            hidden_size: basic_config.hidden_size,
2253            intermediate_size: basic_config.intermediate_size,
2254            num_hidden_layers: basic_config.num_hidden_layers,
2255            num_attention_heads: basic_config.num_attention_heads,
2256            num_key_value_heads: basic_config.num_key_value_heads,
2257            hidden_act: basic_config.hidden_act,
2258            max_position_embeddings: basic_config.max_position_embeddings,
2259            rope_theta: basic_config.rope_theta,
2260            sliding_window: basic_config.sliding_window,
2261            use_flash_attn,
2262            norm_epsilon: basic_config.norm_epsilon,
2263            use_bias: basic_config.use_bias,
2264            quantization_config: basic_config.quantization_config,
2265            tie_word_embeddings: basic_config.tie_word_embeddings,
2266        })
2267    }
2268}
2269
2270/// [`NormalLoader`] for a Starcoder2 model.
2271///
2272/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2273pub struct Starcoder2Loader;
2274
2275impl NormalModelLoader for Starcoder2Loader {
2276    fn load(
2277        &self,
2278        config: &str,
2279        use_flash_attn: bool,
2280        vb: ShardedVarBuilder,
2281        normal_loading_metadata: NormalLoadingMetadata,
2282        attention_mechanism: AttentionImplementation,
2283    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2284        Ok(Box::new(models::starcoder2::Model::new(
2285            &Starcoder2BasicConfig::deserialize(config, use_flash_attn)?,
2286            vb,
2287            self.is_gptx(config)?,
2288            normal_loading_metadata,
2289            attention_mechanism,
2290        )?))
2291    }
2292    fn load_xlora(
2293        &self,
2294        config: &str,
2295        use_flash_attn: bool,
2296        vb: ShardedVarBuilder,
2297        lora_config: &[((String, String), LoraConfig)],
2298        xlora_config: Option<XLoraConfig>,
2299        xlora_ordering: Ordering,
2300        normal_loading_metadata: NormalLoadingMetadata,
2301        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2302    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2303        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2304            &Starcoder2BasicConfig::deserialize(config, use_flash_attn)?,
2305            vb,
2306            lora_config,
2307            xlora_config,
2308            xlora_ordering,
2309            self.is_gptx(config)?,
2310            normal_loading_metadata,
2311            preload_adapters,
2312        )?))
2313    }
2314    fn is_gptx(&self, _: &str) -> Result<bool> {
2315        Ok(true)
2316    }
2317    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2318        Ok(Box::new(serde_json::from_str::<Starcoder2BasicConfig>(
2319            config,
2320        )?))
2321    }
2322}
2323
2324impl IsqModelLoader for Starcoder2Loader {
2325    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2326        Ok(vec![
2327            Regex::new(r"lm_head\.(weight|bias)$")?,
2328            // Attention
2329            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2330            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2331            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2332            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2333            // MLP
2334            Regex::new(r"layers\.(\d+)\.mlp\.c_fc\.(weight|bias)$")?,
2335            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2336        ])
2337    }
2338}
2339
2340impl DeviceMappedModelLoader for Starcoder2Loader {
2341    fn mapped_max_act_size_elems(
2342        &self,
2343        config: &str,
2344        params: &AutoDeviceMapParams,
2345        prompt_chunksize: usize,
2346    ) -> Result<usize> {
2347        let AutoDeviceMapParams::Text {
2348            max_seq_len: _,
2349            max_batch_size,
2350        } = params
2351        else {
2352            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2353        };
2354
2355        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2356
2357        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2358    }
2359    fn non_mapped_max_act_size_elems(
2360        &self,
2361        _config: &str,
2362        _params: &AutoDeviceMapParams,
2363    ) -> Result<usize> {
2364        Ok(0)
2365    }
2366
2367    fn non_mapped_size_in_bytes(
2368        &self,
2369        config: &str,
2370        dtype: DType,
2371        weight_pack_factor: usize,
2372    ) -> Result<usize> {
2373        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2374        let elems = {
2375            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2376            let lm_head = if !cfg.tie_word_embeddings {
2377                cfg.hidden_size * cfg.vocab_size
2378            } else {
2379                0
2380            };
2381            let norm = cfg.hidden_size + cfg.hidden_size;
2382            embed_tokens + lm_head + norm
2383        };
2384        Ok(elems * dtype.size_in_bytes())
2385    }
2386
2387    fn layer_sizes_in_bytes(
2388        &self,
2389        config: &str,
2390        dtype: DType,
2391        weight_pack_factor: usize,
2392    ) -> Result<Vec<usize>> {
2393        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2394        let per_layer_elems = {
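            // Starcoder2 uses LayerNorm with a bias, so each norm contributes
            // 2 × hidden_size elements (weight + bias).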
2395            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2396            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2397
2398            let size_in = cfg.hidden_size;
2399            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2400            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2401            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2402            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2403            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2404            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2405
2406            let h_size = cfg.hidden_size;
2407            let i_size = cfg.intermediate_size;
2408            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2409            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2410
2411            input_layernorm
2412                + post_attention_layernorm
2413                + q_proj
2414                + k_proj
2415                + v_proj
2416                + o_proj
2417                + fc1
2418                + fc2
2419        };
2420        Ok(vec![
2421            per_layer_elems * dtype.size_in_bytes();
2422            cfg.num_hidden_layers
2423        ])
2424    }
2425
2426    fn num_layers(&self, config: &str) -> Result<usize> {
2427        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2428        Ok(cfg.num_hidden_layers)
2429    }
2430
2431    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2432        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2433
2434        let cfg = ModelConfigMetadata {
2435            max_seq_len: cfg.max_position_embeddings,
2436            num_layers: cfg.num_hidden_layers,
2437            hidden_size: cfg.hidden_size,
2438            num_kv_heads: cfg.num_key_value_heads,
2439            num_attn_heads: cfg.num_attention_heads,
2440            sliding_window: cfg.sliding_window,
2441            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2442            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2443        };
2444
2445        Ok(Box::new(cfg))
2446    }
2447}
2448
2449// ======================== Phi 3.5 MoE loader
2450
2451#[derive(Deserialize)]
2452struct Phi3_5MoEBasicConfig {
2453    vocab_size: usize,
2454    hidden_act: Activation,
2455    hidden_size: usize,
2456    intermediate_size: usize,
2457    num_hidden_layers: usize,
2458    num_attention_heads: usize,
2459    num_key_value_heads: usize,
2460    rms_norm_eps: f64,
2461    rope_theta: f64,
2462    rope_scaling: Option<PhiRopeScalingConfig>,
2463    max_position_embeddings: usize,
2464    original_max_position_embeddings: usize,
2465    sliding_window: Option<usize>,
2466    quantization_config: Option<QuantizedConfig>,
2467    lm_head_bias: bool,
2468    attention_bias: bool,
2469    num_local_experts: usize,
2470    router_jitter_noise: f64,
2471    #[serde(default = "word_emb_default")]
2472    tie_word_embeddings: bool,
2473}
2474
2475impl Phi3_5MoEBasicConfig {
2476    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi3_5_moe::Config> {
2477        let basic_config: Self = serde_json::from_str(slice)?;
2478        Ok(models::phi3_5_moe::Config {
2479            vocab_size: basic_config.vocab_size,
2480            hidden_size: basic_config.hidden_size,
2481            intermediate_size: basic_config.intermediate_size,
2482            num_hidden_layers: basic_config.num_hidden_layers,
2483            num_attention_heads: basic_config.num_attention_heads,
2484            num_key_value_heads: basic_config.num_key_value_heads,
2485            hidden_act: basic_config.hidden_act,
2486            max_position_embeddings: basic_config.max_position_embeddings,
2487            rope_theta: basic_config.rope_theta,
2488            rms_norm_eps: basic_config.rms_norm_eps,
2489            rope_scaling: basic_config.rope_scaling,
2490            original_max_position_embeddings: basic_config.original_max_position_embeddings,
2491            use_flash_attn,
2492            sliding_window: basic_config.sliding_window,
2493            quantization_config: basic_config.quantization_config,
2494            lm_head_bias: basic_config.lm_head_bias,
2495            attention_bias: basic_config.attention_bias,
2496            num_local_experts: basic_config.num_local_experts,
2497            router_jitter_noise: basic_config.router_jitter_noise,
2498            tie_word_embeddings: basic_config.tie_word_embeddings,
2499        })
2500    }
2501}
2502
2503/// [`NormalLoader`] for a Phi 3.5 MoE model.
2504///
2505/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2506pub struct Phi3_5MoELoader;
2507
2508impl NormalModelLoader for Phi3_5MoELoader {
2509    fn load(
2510        &self,
2511        config: &str,
2512        use_flash_attn: bool,
2513        vb: ShardedVarBuilder,
2514        normal_loading_metadata: NormalLoadingMetadata,
2515        attention_mechanism: AttentionImplementation,
2516    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2517        Ok(Box::new(models::phi3_5_moe::Model::new(
2518            &Phi3_5MoEBasicConfig::deserialize(config, use_flash_attn)?,
2519            vb,
2520            self.is_gptx(config)?,
2521            normal_loading_metadata,
2522            attention_mechanism,
2523        )?))
2524    }
2525    fn load_xlora(
2526        &self,
2527        config: &str,
2528        use_flash_attn: bool,
2529        vb: ShardedVarBuilder,
2530        lora_config: &[((String, String), LoraConfig)],
2531        xlora_config: Option<XLoraConfig>,
2532        xlora_ordering: Ordering,
2533        normal_loading_metadata: NormalLoadingMetadata,
2534        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2535    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
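        // Note: X-LoRA for Phi 3.5 MoE currently reuses the dense Phi 3 X-LoRA model
        // and `Phi3BasicConfig`, rather than a dedicated MoE X-LoRA implementation.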
2536        Ok(Box::new(xlora_models::XLoraPhi3::new(
2537            &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
2538            vb,
2539            lora_config,
2540            xlora_config,
2541            xlora_ordering,
2542            self.is_gptx(config)?,
2543            normal_loading_metadata,
2544            preload_adapters,
2545        )?))
2546    }
2547    fn is_gptx(&self, _: &str) -> Result<bool> {
2548        Ok(true)
2549    }
2550    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2551        Ok(Box::new(Phi3_5MoEBasicConfig::deserialize(
2552            config,
2553            use_flash_attn,
2554        )?))
2555    }
2556}
2557
2558impl IsqModelLoader for Phi3_5MoELoader {
2559    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2560        Ok(vec![
2561            Regex::new(r"lm_head\.(weight|bias)$")?,
2562            // Attention
2563            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2564            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2565            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2566            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2567            // MLP
2568            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2569            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2570            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2571        ])
2572    }
2573
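    // MoQE-style ISQ: quantize only the expert MLP weights (plus lm_head), leaving the
    // attention projections unquantized.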
2574    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2575        Ok(vec![
2576            Regex::new(r"lm_head\.(weight|bias)$")?,
2577            // MLP
2578            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2579            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2580            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2581        ])
2582    }
2583}
2584
2585impl DeviceMappedModelLoader for Phi3_5MoELoader {
2586    fn mapped_max_act_size_elems(
2587        &self,
2588        config: &str,
2589        params: &AutoDeviceMapParams,
2590        prompt_chunksize: usize,
2591    ) -> Result<usize> {
2592        let AutoDeviceMapParams::Text {
2593            max_seq_len: _,
2594            max_batch_size,
2595        } = params
2596        else {
2597            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2598        };
2599
2600        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2601
2602        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2603    }
2604    fn non_mapped_max_act_size_elems(
2605        &self,
2606        _config: &str,
2607        _params: &AutoDeviceMapParams,
2608    ) -> Result<usize> {
2609        Ok(0)
2610    }
2611
2612    fn non_mapped_size_in_bytes(
2613        &self,
2614        config: &str,
2615        dtype: DType,
2616        weight_pack_factor: usize,
2617    ) -> Result<usize> {
2618        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2619        let elems = {
2620            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2621            let lm_head = if !cfg.tie_word_embeddings {
2622                cfg.hidden_size * cfg.vocab_size
2623            } else {
2624                0
2625            };
2626            let norm = cfg.hidden_size;
2627            embed_tokens + lm_head + norm
2628        };
2629        Ok(elems * dtype.size_in_bytes())
2630    }
2631
2632    fn layer_sizes_in_bytes(
2633        &self,
2634        config: &str,
2635        dtype: DType,
2636        weight_pack_factor: usize,
2637    ) -> Result<Vec<usize>> {
2638        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2639        let per_layer_elems = {
2640            let input_layernorm = cfg.hidden_size;
2641            let post_attention_layernorm = cfg.hidden_size;
2642
2643            let size_in = cfg.hidden_size;
2644            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2645            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2646            let q_proj =
2647                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2648            let k_proj =
2649                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2650            let v_proj =
2651                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2652            let o_proj =
2653                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2654
2655            let moe_block = {
2656                let gate = cfg.hidden_size * cfg.num_local_experts;
2657                // Assume the expert weights are quantized with the same weight pack factor
2658                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2659                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2660                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2661                gate + cfg.num_local_experts * w1
2662                    + cfg.num_local_experts * w2
2663                    + cfg.num_local_experts * w3
2664            };
2665
2666            input_layernorm
2667                + post_attention_layernorm
2668                + q_proj
2669                + k_proj
2670                + v_proj
2671                + o_proj
2672                + moe_block
2673        };
2674        Ok(vec![
2675            per_layer_elems * dtype.size_in_bytes();
2676            cfg.num_hidden_layers
2677        ])
2678    }
2679
2680    fn num_layers(&self, config: &str) -> Result<usize> {
2681        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2682        Ok(cfg.num_hidden_layers)
2683    }
2684
2685    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2686        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2687
2688        let cfg = ModelConfigMetadata {
2689            max_seq_len: cfg.max_position_embeddings,
2690            num_layers: cfg.num_hidden_layers,
2691            hidden_size: cfg.hidden_size,
2692            num_kv_heads: cfg.num_key_value_heads,
2693            num_attn_heads: cfg.num_attention_heads,
2694            sliding_window: cfg.sliding_window,
2695            k_head_dim: cfg.head_dim(),
2696            v_head_dim: cfg.head_dim(),
2697        };
2698
2699        Ok(Box::new(cfg))
2700    }
2701}
2702
2703/// [`NormalLoader`] for a DeepSeekV2 model.
2704///
2705/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2706pub struct DeepSeekV2Loader;
2707
2708impl NormalModelLoader for DeepSeekV2Loader {
2709    fn load(
2710        &self,
2711        config: &str,
2712        use_flash_attn: bool,
2713        vb: ShardedVarBuilder,
2714        normal_loading_metadata: NormalLoadingMetadata,
2715        attention_mechanism: AttentionImplementation,
2716    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2717        let mut cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2718        cfg.use_flash_attn = use_flash_attn;
2719        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2720            &cfg,
2721            vb,
2722            self.is_gptx(config)?,
2723            normal_loading_metadata,
2724            attention_mechanism,
2725        )?))
2726    }
2727    fn load_xlora(
2728        &self,
2729        _config: &str,
2730        _use_flash_attn: bool,
2731        _vb: ShardedVarBuilder,
2732        _lora_config: &[((String, String), LoraConfig)],
2733        _xlora_config: Option<XLoraConfig>,
2734        _xlora_ordering: Ordering,
2735        _normal_loading_metadata: NormalLoadingMetadata,
2736        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2737    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2738        todo!()
2739    }
2740    fn is_gptx(&self, _: &str) -> Result<bool> {
2741        Ok(true)
2742    }
2743    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2744        let mut config: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2745        config.use_flash_attn = use_flash_attn;
2746        Ok(Box::new(config))
2747    }
2748}
2749
2750impl IsqModelLoader for DeepSeekV2Loader {
2751    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2752        let mut data = vec![
2753            Regex::new(r"lm_head\.(weight|bias)$")?,
2754            // Attention
2755            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2756            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2757            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2758        ];
2759        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2760        if cfg.q_lora_rank.is_some() {
2761            data.extend(vec![
2762                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2763                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2764            ]);
2765        } else {
2766            data.push(Regex::new(
2767                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2768            )?);
2769        }
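        // Per-layer MLP regexes: a layer is treated as MoE only when routed experts are
        // configured, its index is at or past `first_k_dense_replace`, and it lands on
        // the `moe_layer_freq` stride; every other layer gets the dense gate/up/down regexes.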
2770        for layer_idx in 0..cfg.num_hidden_layers {
2771            if cfg.n_routed_experts.is_some()
2772                && layer_idx >= cfg.first_k_dense_replace
2773                && layer_idx % cfg.moe_layer_freq == 0
2774            {
2775                for i in 0..cfg.n_routed_experts.unwrap() {
2776                    data.extend(vec![
2777                        Regex::new(&format!(
2778                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2779                        ))?,
2780                        Regex::new(&format!(
2781                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2782                        ))?,
2783                        Regex::new(&format!(
2784                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2785                        ))?,
2786                    ]);
2787                }
2788                if cfg.n_shared_experts.is_some() {
2789                    data.extend(vec![
2790                        Regex::new(&format!(
2791                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2792                        ))?,
2793                        Regex::new(&format!(
2794                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2795                        ))?,
2796                        Regex::new(&format!(
2797                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2798                        ))?,
2799                    ]);
2800                }
2801            } else {
2802                data.extend(vec![
2803                    Regex::new(&format!(
2804                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2805                    ))?,
2806                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2807                    Regex::new(&format!(
2808                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2809                    ))?,
2810                ]);
2811            };
2812        }
2813        Ok(data)
2814    }
2815
2816    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2817        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2818        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2819        for layer_idx in 0..cfg.num_hidden_layers {
2820            if cfg.n_routed_experts.is_some()
2821                && layer_idx >= cfg.first_k_dense_replace
2822                && layer_idx % cfg.moe_layer_freq == 0
2823            {
2824                for i in 0..cfg.n_routed_experts.unwrap() {
2825                    data.extend(vec![
2826                        Regex::new(&format!(
2827                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2828                        ))?,
2829                        Regex::new(&format!(
2830                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2831                        ))?,
2832                        Regex::new(&format!(
2833                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2834                        ))?,
2835                    ]);
2836                }
2837                if cfg.n_shared_experts.is_some() {
2838                    data.extend(vec![
2839                        Regex::new(&format!(
2840                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2841                        ))?,
2842                        Regex::new(&format!(
2843                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2844                        ))?,
2845                        Regex::new(&format!(
2846                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2847                        ))?,
2848                    ]);
2849                }
2850            } else {
2851                data.extend(vec![
2852                    Regex::new(&format!(
2853                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2854                    ))?,
2855                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2856                    Regex::new(&format!(
2857                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2858                    ))?,
2859                ]);
2860            };
2861        }
2862        Ok(data)
2863    }
2864}
2865
2866impl DeviceMappedModelLoader for DeepSeekV2Loader {
2867    fn mapped_max_act_size_elems(
2868        &self,
2869        config: &str,
2870        params: &AutoDeviceMapParams,
2871        prompt_chunksize: usize,
2872    ) -> Result<usize> {
2873        let AutoDeviceMapParams::Text {
2874            max_seq_len: _,
2875            max_batch_size,
2876        } = params
2877        else {
2878            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2879        };
2880
2881        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2882
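        // Estimate of the largest mapped activation: the attention score matrix (batch * heads * chunk^2).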
2883        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2884    }
2885    fn non_mapped_max_act_size_elems(
2886        &self,
2887        _config: &str,
2888        _params: &AutoDeviceMapParams,
2889    ) -> Result<usize> {
2890        Ok(0)
2891    }
2892
2893    fn non_mapped_size_in_bytes(
2894        &self,
2895        config: &str,
2896        dtype: DType,
2897        weight_pack_factor: usize,
2898    ) -> Result<usize> {
2899        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2900        let elems = {
2901            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2902            let lm_head = if !cfg.tie_word_embeddings {
2903                cfg.hidden_size * cfg.vocab_size
2904            } else {
2905                0
2906            };
2907            let norm = cfg.hidden_size;
2908            embed_tokens + lm_head + norm
2909        };
2910        Ok(elems * dtype.size_in_bytes())
2911    }
2912
2913    fn layer_sizes_in_bytes(
2914        &self,
2915        config: &str,
2916        dtype: DType,
2917        weight_pack_factor: usize,
2918    ) -> Result<Vec<usize>> {
2919        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2920        let mut per_layer_elems = Vec::new();
2921
2922        for layer_idx in 0..cfg.num_hidden_layers {
2923            let input_layernorm = cfg.hidden_size;
2924            let post_attention_layernorm = cfg.hidden_size;
2925
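            // MLA query projection: a low-rank pair (down-projection + norm + up-projection) when q_lora_rank is set, otherwise a single dense projection.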
2926            let q_proj = match cfg.q_lora_rank {
2927                Some(lora_rank) => {
2928                    let a = cfg.hidden_size * lora_rank;
2929                    let norm = lora_rank;
2930                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2931                    a + norm + b
2932                }
2933                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2934            };
2935            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2936                / weight_pack_factor
2937                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2938            let kv_a_layernorm = cfg.kv_lora_rank;
2939            let kv_b_proj = cfg.kv_lora_rank
2940                * cfg.num_attention_heads
2941                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2942                / weight_pack_factor;
2943            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2944                / weight_pack_factor
2945                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2946
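            // A layer uses routed experts only when n_routed_experts is set, the index is past first_k_dense_replace, and it falls on the moe_layer_freq stride; otherwise it is a dense MLP.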
2947            let moe_block = {
2948                let mut sum = 0;
2949                if cfg.n_routed_experts.is_some()
2950                    && layer_idx >= cfg.first_k_dense_replace
2951                    && layer_idx % cfg.moe_layer_freq == 0
2952                {
2953                    let h_size = cfg.hidden_size;
2954                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2955                        * cfg.n_routed_experts.unwrap();
2956                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2957                        * cfg.n_routed_experts.unwrap();
2958                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2959                        * cfg.n_routed_experts.unwrap();
2960                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2961                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2962                            / weight_pack_factor;
2963                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2964                            / weight_pack_factor;
2965                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2966                            / weight_pack_factor;
2967                        gate_proj + up_proj + down_proj
2968                    } else {
2969                        0
2970                    };
2971                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2972                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2973                } else {
2974                    let h_size = cfg.hidden_size;
2975                    let i_size = cfg.intermediate_size;
2976                    let gate_proj = h_size * i_size / weight_pack_factor;
2977                    let up_proj = h_size * i_size / weight_pack_factor;
2978                    let down_proj = i_size * h_size / weight_pack_factor;
2979                    sum += gate_proj + up_proj + down_proj;
2980                }
2981                sum
2982            };
2983
2984            per_layer_elems.push(
2985                input_layernorm
2986                    + post_attention_layernorm
2987                    + q_proj
2988                    + kv_a_layernorm
2989                    + kv_a_proj_with_mqa
2990                    + kv_b_proj
2991                    + o_proj
2992                    + moe_block,
2993            );
2994        }
2995
2996        Ok(per_layer_elems
2997            .into_iter()
2998            .map(|x| x * dtype.size_in_bytes())
2999            .collect())
3000    }
3001
3002    fn num_layers(&self, config: &str) -> Result<usize> {
3003        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
3004        Ok(cfg.num_hidden_layers)
3005    }
3006
3007    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3008        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
3009
3010        let cfg = ModelConfigMetadata {
3011            max_seq_len: cfg.max_position_embeddings,
3012            num_layers: cfg.num_hidden_layers,
3013            hidden_size: cfg.hidden_size,
3014            num_kv_heads: cfg.num_attention_heads,
3015            num_attn_heads: cfg.num_attention_heads,
3016            sliding_window: None,
3017            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3018            v_head_dim: cfg.v_head_dim,
3019        };
3020
3021        Ok(Box::new(cfg))
3022    }
3023}
3024
3025/// [`NormalLoader`] for a DeepSeekV3 model.
3026///
3027/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3028pub struct DeepSeekV3Loader;
3029
3030impl NormalModelLoader for DeepSeekV3Loader {
3031    fn load(
3032        &self,
3033        config: &str,
3034        use_flash_attn: bool,
3035        vb: ShardedVarBuilder,
3036        normal_loading_metadata: NormalLoadingMetadata,
3037        attention_mechanism: AttentionImplementation,
3038    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3039        let mut cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3040        cfg.use_flash_attn = use_flash_attn;
3041        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
3042            &cfg,
3043            vb,
3044            self.is_gptx(config)?,
3045            normal_loading_metadata,
3046            attention_mechanism,
3047        )?))
3048    }
3049    fn load_xlora(
3050        &self,
3051        _config: &str,
3052        _use_flash_attn: bool,
3053        _vb: ShardedVarBuilder,
3054        _lora_config: &[((String, String), LoraConfig)],
3055        _xlora_config: Option<XLoraConfig>,
3056        _xlora_ordering: Ordering,
3057        _normal_loading_metadata: NormalLoadingMetadata,
3058        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3059    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3060        todo!()
3061    }
3062    fn is_gptx(&self, _: &str) -> Result<bool> {
3063        Ok(true)
3064    }
3065    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
3066        let mut config: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3067        config.use_flash_attn = use_flash_attn;
3068        Ok(Box::new(config))
3069    }
3070}
3071
3072impl IsqModelLoader for DeepSeekV3Loader {
3073    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3074        let mut data = vec![
3075            Regex::new(r"lm_head\.(weight|bias)$")?,
3076            // Attention
3077            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3078            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3079            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3080        ];
3081        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3082        if cfg.q_lora_rank.is_some() {
3083            data.extend(vec![
3084                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3085                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3086            ]);
3087        } else {
3088            data.push(Regex::new(
3089                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
3090            )?);
3091        }
3092        for layer_idx in 0..cfg.num_hidden_layers {
3093            if cfg.n_routed_experts.is_some()
3094                && layer_idx >= cfg.first_k_dense_replace
3095                && layer_idx % cfg.moe_layer_freq == 0
3096            {
3097                for i in 0..cfg.n_routed_experts.unwrap() {
3098                    data.extend(vec![
3099                        Regex::new(&format!(
3100                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3101                        ))?,
3102                        Regex::new(&format!(
3103                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3104                        ))?,
3105                        Regex::new(&format!(
3106                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3107                        ))?,
3108                    ]);
3109                }
3110                if cfg.n_shared_experts.is_some() {
3111                    data.extend(vec![
3112                        Regex::new(&format!(
3113                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3114                        ))?,
3115                        Regex::new(&format!(
3116                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3117                        ))?,
3118                        Regex::new(&format!(
3119                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3120                        ))?,
3121                    ]);
3122                }
3123            } else {
3124                data.extend(vec![
3125                    Regex::new(&format!(
3126                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3127                    ))?,
3128                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
3129                    Regex::new(&format!(
3130                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3131                    ))?,
3132                ]);
3133            };
3134        }
3135        Ok(data)
3136    }
3137
3138    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3139        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3140        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3141        for layer_idx in 0..cfg.num_hidden_layers {
3142            if cfg.n_routed_experts.is_some()
3143                && layer_idx >= cfg.first_k_dense_replace
3144                && layer_idx % cfg.moe_layer_freq == 0
3145            {
3146                for i in 0..cfg.n_routed_experts.unwrap() {
3147                    data.extend(vec![
3148                        Regex::new(&format!(
3149                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3150                        ))?,
3151                        Regex::new(&format!(
3152                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3153                        ))?,
3154                        Regex::new(&format!(
3155                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3156                        ))?,
3157                    ]);
3158                }
3159                if cfg.n_shared_experts.is_some() {
3160                    data.extend(vec![
3161                        Regex::new(&format!(
3162                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3163                        ))?,
3164                        Regex::new(&format!(
3165                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3166                        ))?,
3167                        Regex::new(&format!(
3168                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3169                        ))?,
3170                    ]);
3171                }
3172            } else {
3173                data.extend(vec![
3174                    Regex::new(&format!(
3175                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3176                    ))?,
3177                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
3178                    Regex::new(&format!(
3179                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3180                    ))?,
3181                ]);
3182            };
3183        }
3184        Ok(data)
3185    }
3186}
3187
3188impl DeviceMappedModelLoader for DeepSeekV3Loader {
3189    fn mapped_max_act_size_elems(
3190        &self,
3191        config: &str,
3192        params: &AutoDeviceMapParams,
3193        prompt_chunksize: usize,
3194    ) -> Result<usize> {
3195        let AutoDeviceMapParams::Text {
3196            max_seq_len: _,
3197            max_batch_size,
3198        } = params
3199        else {
3200            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3201        };
3202
3203        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3204
3205        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3206    }
3207    fn non_mapped_max_act_size_elems(
3208        &self,
3209        _config: &str,
3210        _params: &AutoDeviceMapParams,
3211    ) -> Result<usize> {
3212        Ok(0)
3213    }
3214
3215    fn non_mapped_size_in_bytes(
3216        &self,
3217        config: &str,
3218        dtype: DType,
3219        weight_pack_factor: usize,
3220    ) -> Result<usize> {
3221        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3222        let elems = {
3223            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3224            let lm_head = if !cfg.tie_word_embeddings {
3225                cfg.hidden_size * cfg.vocab_size
3226            } else {
3227                0
3228            };
3229            let norm = cfg.hidden_size;
3230            embed_tokens + lm_head + norm
3231        };
3232        Ok(elems * dtype.size_in_bytes())
3233    }
3234
3235    fn layer_sizes_in_bytes(
3236        &self,
3237        config: &str,
3238        dtype: DType,
3239        weight_pack_factor: usize,
3240    ) -> Result<Vec<usize>> {
3241        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3242        let mut per_layer_elems = Vec::new();
3243
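        // Per-layer accounting mirrors DeepSeekV2: MLA attention projections plus either a dense MLP or a routed-expert MoE block.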
3244        for layer_idx in 0..cfg.num_hidden_layers {
3245            let input_layernorm = cfg.hidden_size;
3246            let post_attention_layernorm = cfg.hidden_size;
3247
3248            let q_proj = match cfg.q_lora_rank {
3249                Some(lora_rank) => {
3250                    let a = cfg.hidden_size * lora_rank;
3251                    let norm = lora_rank;
3252                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
3253                    a + norm + b
3254                }
3255                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
3256            };
3257            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
3258                / weight_pack_factor
3259                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
3260            let kv_a_layernorm = cfg.kv_lora_rank;
3261            let kv_b_proj = cfg.kv_lora_rank
3262                * cfg.num_attention_heads
3263                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3264                / weight_pack_factor;
3265            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
3266                / weight_pack_factor
3267                + bias_if!(cfg.attention_bias, cfg.hidden_size);
3268
3269            let moe_block = {
3270                let mut sum = 0;
3271                if cfg.n_routed_experts.is_some()
3272                    && layer_idx >= cfg.first_k_dense_replace
3273                    && layer_idx % cfg.moe_layer_freq == 0
3274                {
3275                    let h_size = cfg.hidden_size;
3276                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3277                        * cfg.n_routed_experts.unwrap();
3278                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3279                        * cfg.n_routed_experts.unwrap();
3280                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3281                        * cfg.n_routed_experts.unwrap();
3282                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3283                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3284                            / weight_pack_factor;
3285                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3286                            / weight_pack_factor;
3287                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3288                            / weight_pack_factor;
3289                        gate_proj + up_proj + down_proj
3290                    } else {
3291                        0
3292                    };
3293                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
3294                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3295                } else {
3296                    let h_size = cfg.hidden_size;
3297                    let i_size = cfg.intermediate_size;
3298                    let gate_proj = h_size * i_size / weight_pack_factor;
3299                    let up_proj = h_size * i_size / weight_pack_factor;
3300                    let down_proj = i_size * h_size / weight_pack_factor;
3301                    sum += gate_proj + up_proj + down_proj;
3302                }
3303                sum
3304            };
3305
3306            per_layer_elems.push(
3307                input_layernorm
3308                    + post_attention_layernorm
3309                    + q_proj
3310                    + kv_a_layernorm
3311                    + kv_a_proj_with_mqa
3312                    + kv_b_proj
3313                    + o_proj
3314                    + moe_block,
3315            );
3316        }
3317
3318        Ok(per_layer_elems
3319            .into_iter()
3320            .map(|x| x * dtype.size_in_bytes())
3321            .collect())
3322    }
3323
3324    fn num_layers(&self, config: &str) -> Result<usize> {
3325        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3326        Ok(cfg.num_hidden_layers)
3327    }
3328
3329    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3330        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3331
3332        let cfg = ModelConfigMetadata {
3333            max_seq_len: cfg.max_position_embeddings,
3334            num_layers: cfg.num_hidden_layers,
3335            hidden_size: cfg.hidden_size,
3336            num_kv_heads: cfg.num_attention_heads,
3337            num_attn_heads: cfg.num_attention_heads,
3338            sliding_window: None,
3339            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3340            v_head_dim: cfg.v_head_dim,
3341        };
3342
3343        Ok(Box::new(cfg))
3344    }
3345}
3346
3347/// [`NormalLoader`] for a Qwen 3 model.
3348///
3349/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3350pub struct Qwen3Loader;
3351
3352impl NormalModelLoader for Qwen3Loader {
3353    fn load(
3354        &self,
3355        config: &str,
3356        use_flash_attn: bool,
3357        vb: ShardedVarBuilder,
3358        normal_loading_metadata: NormalLoadingMetadata,
3359        attention_mechanism: AttentionImplementation,
3360    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3361        let mut cfg: models::qwen3::Config = serde_json::from_str(config)?;
3362        cfg.use_flash_attn = use_flash_attn;
3363
3364        Ok(Box::new(models::qwen3::Model::new(
3365            &cfg,
3366            vb,
3367            self.is_gptx(config)?,
3368            normal_loading_metadata,
3369            attention_mechanism,
3370        )?))
3371    }
3372    fn load_xlora(
3373        &self,
3374        _config: &str,
3375        _use_flash_attn: bool,
3376        _vb: ShardedVarBuilder,
3377        _lora_config: &[((String, String), LoraConfig)],
3378        _xlora_config: Option<XLoraConfig>,
3379        _xlora_ordering: Ordering,
3380        _normal_loading_metadata: NormalLoadingMetadata,
3381        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3382    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3383        todo!()
3384    }
3385    fn is_gptx(&self, _: &str) -> Result<bool> {
3386        Ok(true)
3387    }
3388    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
3389        let mut cfg: models::qwen3::Config = serde_json::from_str(config)?;
3390        cfg.use_flash_attn = use_flash_attn;
3391
3392        Ok(Box::new(cfg))
3393    }
3394}
3395
3396impl IsqModelLoader for Qwen3Loader {
3397    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3398        Ok(vec![
3399            Regex::new(r"lm_head\.(weight|bias)$")?,
3400            // Attention
3401            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3402            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3403            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3404            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
3405            // MLP
3406            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3407            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3408            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3409        ])
3410    }
3411}
3412
3413impl DeviceMappedModelLoader for Qwen3Loader {
3414    fn mapped_max_act_size_elems(
3415        &self,
3416        config: &str,
3417        params: &AutoDeviceMapParams,
3418        prompt_chunksize: usize,
3419    ) -> Result<usize> {
3420        let AutoDeviceMapParams::Text {
3421            max_seq_len: _,
3422            max_batch_size,
3423        } = params
3424        else {
3425            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3426        };
3427
3428        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3429
3430        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3431    }
3432    fn non_mapped_max_act_size_elems(
3433        &self,
3434        _config: &str,
3435        _params: &AutoDeviceMapParams,
3436    ) -> Result<usize> {
3437        Ok(0)
3438    }
3439
3440    fn non_mapped_size_in_bytes(
3441        &self,
3442        config: &str,
3443        dtype: DType,
3444        weight_pack_factor: usize,
3445    ) -> Result<usize> {
3446        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3447        let elems = {
3448            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3449            let lm_head = if !cfg.tie_word_embeddings {
3450                cfg.hidden_size * cfg.vocab_size
3451            } else {
3452                0
3453            };
3454            let norm = cfg.hidden_size;
3455            embed_tokens + lm_head + norm
3456        };
3457        Ok(elems * dtype.size_in_bytes())
3458    }
3459
3460    fn layer_sizes_in_bytes(
3461        &self,
3462        config: &str,
3463        dtype: DType,
3464        weight_pack_factor: usize,
3465    ) -> Result<Vec<usize>> {
3466        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
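        // All Qwen3 decoder layers share the same shapes, so size one layer and replicate it num_hidden_layers times.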
3467        let per_layer_elems = {
3468            let input_layernorm = cfg.hidden_size;
3469            let post_attention_layernorm = cfg.hidden_size;
3470
3471            let size_in = cfg.hidden_size;
3472            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3473            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3474            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3475            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3476            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3477            let o_proj = size_q * size_in / weight_pack_factor;
3478
3479            let h_size = cfg.hidden_size;
3480            let i_size = cfg.intermediate_size;
3481            let gate_proj = h_size * i_size / weight_pack_factor;
3482            let up_proj = h_size * i_size / weight_pack_factor;
3483            let down_proj = i_size * h_size / weight_pack_factor;
3484
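            // Qwen3 applies per-head RMSNorm to q and k; each norm holds head_dim elements.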
3485            let q_norm = cfg.head_dim();
3486            let k_norm = cfg.head_dim();
3487
3488            input_layernorm
3489                + post_attention_layernorm
3490                + q_proj
3491                + k_proj
3492                + v_proj
3493                + o_proj
3494                + gate_proj
3495                + up_proj
3496                + down_proj
3497                + q_norm
3498                + k_norm
3499        };
3500        Ok(vec![
3501            per_layer_elems * dtype.size_in_bytes();
3502            cfg.num_hidden_layers
3503        ])
3504    }
3505
3506    fn num_layers(&self, config: &str) -> Result<usize> {
3507        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3508        Ok(cfg.num_hidden_layers)
3509    }
3510
3511    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3512        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3513
3514        let cfg = ModelConfigMetadata {
3515            max_seq_len: cfg.max_position_embeddings,
3516            num_layers: cfg.num_hidden_layers,
3517            hidden_size: cfg.hidden_size,
3518            num_kv_heads: cfg.num_key_value_heads,
3519            num_attn_heads: cfg.num_attention_heads,
3520            sliding_window: cfg.sliding_window,
3521            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3522            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3523        };
3524
3525        Ok(Box::new(cfg))
3526    }
3527}
3528
3529/// [`NormalLoader`] for a Qwen 3 MoE model.
3530///
3531/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3532pub struct Qwen3MoELoader;
3533
3534impl NormalModelLoader for Qwen3MoELoader {
3535    fn load(
3536        &self,
3537        config: &str,
3538        use_flash_attn: bool,
3539        vb: ShardedVarBuilder,
3540        normal_loading_metadata: NormalLoadingMetadata,
3541        attention_mechanism: AttentionImplementation,
3542    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3543        let mut cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3544        cfg.use_flash_attn = use_flash_attn;
3545
3546        Ok(Box::new(models::qwen3_moe::Model::new(
3547            &cfg,
3548            vb,
3549            self.is_gptx(config)?,
3550            normal_loading_metadata,
3551            attention_mechanism,
3552        )?))
3553    }
3554    fn load_xlora(
3555        &self,
3556        _config: &str,
3557        _use_flash_attn: bool,
3558        _vb: ShardedVarBuilder,
3559        _lora_config: &[((String, String), LoraConfig)],
3560        _xlora_config: Option<XLoraConfig>,
3561        _xlora_ordering: Ordering,
3562        _normal_loading_metadata: NormalLoadingMetadata,
3563        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3564    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3565        todo!()
3566    }
3567    fn is_gptx(&self, _: &str) -> Result<bool> {
3568        Ok(true)
3569    }
3570    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
3571        let mut cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3572        cfg.use_flash_attn = use_flash_attn;
3573
3574        Ok(Box::new(cfg))
3575    }
3576}
3577
3578impl IsqModelLoader for Qwen3MoELoader {
3579    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3580        Ok(vec![
3581            Regex::new(r"lm_head\.(weight|bias)$")?,
3582            // Attention
3583            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3584            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3585            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3586            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
3587            // MLP
3588            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3589            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3590            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3591            // MLP MoE
3592            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3593            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3594            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3595        ])
3596    }
3597}
3598
3599impl DeviceMappedModelLoader for Qwen3MoELoader {
3600    fn mapped_max_act_size_elems(
3601        &self,
3602        config: &str,
3603        params: &AutoDeviceMapParams,
3604        prompt_chunksize: usize,
3605    ) -> Result<usize> {
3606        let AutoDeviceMapParams::Text {
3607            max_seq_len: _,
3608            max_batch_size,
3609        } = params
3610        else {
3611            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3612        };
3613
3614        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3615
3616        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3617    }
3618    fn non_mapped_max_act_size_elems(
3619        &self,
3620        _config: &str,
3621        _params: &AutoDeviceMapParams,
3622    ) -> Result<usize> {
3623        Ok(0)
3624    }
3625
3626    fn non_mapped_size_in_bytes(
3627        &self,
3628        config: &str,
3629        dtype: DType,
3630        weight_pack_factor: usize,
3631    ) -> Result<usize> {
3632        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3633        let elems = {
3634            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3635            let lm_head = if !cfg.tie_word_embeddings {
3636                cfg.hidden_size * cfg.vocab_size
3637            } else {
3638                0
3639            };
3640            let norm = cfg.hidden_size;
3641            embed_tokens + lm_head + norm
3642        };
3643        Ok(elems * dtype.size_in_bytes())
3644    }
3645
3646    fn layer_sizes_in_bytes(
3647        &self,
3648        config: &str,
3649        dtype: DType,
3650        weight_pack_factor: usize,
3651    ) -> Result<Vec<usize>> {
3652        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3653        let mut layer_sizes_in_bytes = Vec::new();
3654        for layer_idx in 0..cfg.num_hidden_layers {
3655            let input_layernorm = cfg.hidden_size;
3656            let post_attention_layernorm = cfg.hidden_size;
3657
3658            let size_in = cfg.hidden_size;
3659            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3660            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3661            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3662            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3663            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3664            let o_proj = size_q * size_in / weight_pack_factor;
3665
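            // A layer is sparse (MoE) unless listed in mlp_only_layers; sparse layers occur every decoder_sparse_step layers and hold num_experts expert MLPs sized by moe_intermediate_size.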
3666            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3667                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3668            {
3669                let expert_size = {
3670                    let h_size = cfg.hidden_size;
3671                    let i_size = cfg.moe_intermediate_size;
3672                    let gate_proj = h_size * i_size / weight_pack_factor;
3673                    let up_proj = h_size * i_size / weight_pack_factor;
3674                    let down_proj = i_size * h_size / weight_pack_factor;
3675                    gate_proj + up_proj + down_proj
3676                };
3677                expert_size * cfg.num_experts
3678            } else {
3679                let h_size = cfg.hidden_size;
3680                let i_size = cfg.intermediate_size;
3681                let gate_proj = h_size * i_size / weight_pack_factor;
3682                let up_proj = h_size * i_size / weight_pack_factor;
3683                let down_proj = i_size * h_size / weight_pack_factor;
3684                gate_proj + up_proj + down_proj
3685            };
3686
3687            let q_norm = cfg.head_dim();
3688            let k_norm = cfg.head_dim();
3689
3690            let size_elems = input_layernorm
3691                + post_attention_layernorm
3692                + q_proj
3693                + k_proj
3694                + v_proj
3695                + o_proj
3696                + mlp_size
3697                + q_norm
3698                + k_norm;
3699
3700            let size_in_bytes = size_elems * dtype.size_in_bytes();
3701            layer_sizes_in_bytes.push(size_in_bytes);
3702        }
3703
3704        Ok(layer_sizes_in_bytes)
3705    }
3706
3707    fn num_layers(&self, config: &str) -> Result<usize> {
3708        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3709        Ok(cfg.num_hidden_layers)
3710    }
3711
3712    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3713        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3714
3715        let cfg = ModelConfigMetadata {
3716            max_seq_len: cfg.max_position_embeddings,
3717            num_layers: cfg.num_hidden_layers,
3718            hidden_size: cfg.hidden_size,
3719            num_kv_heads: cfg.num_key_value_heads,
3720            num_attn_heads: cfg.num_attention_heads,
3721            sliding_window: cfg.sliding_window,
3722            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3723            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3724        };
3725
3726        Ok(Box::new(cfg))
3727    }
3728}