mistralrs_core/pipeline/loaders/normal_loaders.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    str::FromStr,
5    sync::Arc,
6};
7
8use crate::{
9    amoe::AnyMoeBaseModelMixin,
10    device_map::DeviceMapper,
11    lora::{LoraConfig, Ordering},
12    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
13    pipeline::{
14        isq::IsqModelLoader,
15        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
16        EitherCache, IsqModel,
17    },
18    utils::varbuilder_utils::DeviceForLoadTensor,
19    xlora_models::NonGranularState,
20};
21use anyhow::Result;
22use candle_core::{DType, Device, Tensor};
23use mistralrs_quant::log::once_log_info;
24
25use indicatif::MultiProgress;
26use mistralrs_quant::ShardedVarBuilder;
27#[cfg(feature = "pyo3_macros")]
28use pyo3::pyclass;
29
30use regex::Regex;
31use serde::Deserialize;
32
33use crate::{
34    models,
35    xlora_models::{self, XLoraConfig},
36};
37
38use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
39
40pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
41    #[allow(clippy::too_many_arguments)]
42    fn forward(
43        &self,
44        input_ids: &Tensor,
45        seqlen_offsets: &[usize],
46        context_lens: Vec<(usize, usize)>,
47        position_ids: Vec<usize>,
48        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
49        flash_params: &FlashParams,
50    ) -> candle_core::Result<Tensor>;
51    #[allow(clippy::too_many_arguments)]
52    fn xlora_forward(
53        &self,
54        input_ids: &Tensor,
55        input_ids_full: &Tensor,
56        seqlen_offsets: &[usize],
57        seqlen_offsets_full: &[usize],
58        no_kv_cache: bool,
59        non_granular_state: &Option<NonGranularState>,
60        context_lens: Vec<(usize, usize)>,
61        position_ids: Vec<usize>,
62        flash_params: &FlashParams,
63        flash_params_full: &FlashParams,
64    ) -> candle_core::Result<Tensor>;
65    fn is_xlora(&self) -> bool;
66    fn device(&self) -> &Device;
67    fn cache(&self) -> &EitherCache;
68    fn cache_mut(&mut self) -> &mut EitherCache;
69    fn max_seq_len(&self) -> usize;
70    fn config(&self) -> &ModelConfigMetadata;
71}
72
73/// Metadata for loading a model with ISQ or device mapping.
74pub struct NormalLoadingMetadata {
75    // Device mapping metadata which can be used to construct a concrete device mapper
76    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
77    // Flag indicating whether the model is being loaded for in-situ quantization (ISQ)
78    pub loading_isq: bool,
79    // Device mapping target device (i.e., the device that is not the CPU)
80    pub real_device: Device,
81    // MultiProgress support for parallelized loading
82    pub multi_progress: Arc<MultiProgress>,
83}
84
85pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
86    fn load(
87        &self,
88        config: &str,
89        vb: ShardedVarBuilder,
90        normal_loading_metadata: NormalLoadingMetadata,
91        attention_mechanism: AttentionImplementation,
92    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
93    #[allow(clippy::too_many_arguments)]
94    fn load_xlora(
95        &self,
96        config: &str,
97        vb: ShardedVarBuilder,
98        lora_config: &[((String, String), LoraConfig)],
99        xlora_config: Option<XLoraConfig>,
100        xlora_ordering: Ordering,
101        normal_loading_metadata: NormalLoadingMetadata,
102        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
103    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
104    fn is_gptx(&self, config: &str) -> Result<bool>;
105    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
106        Ok(true)
107    }
108    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
109    fn get_device_for_tensor(
110        &self,
111        config: &str,
112        _mapper: &dyn DeviceMapper,
113        loading_isq: bool,
114    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
115        if loading_isq {
116            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
117        } else {
118            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
119            let num_layers = self.model_config(config)?.num_layers();
120            let closure = move |name: String| {
121                if let Some(captures) = re.captures(&name) {
122                    captures
123                        .get(1)
124                        .and_then(|m| m.as_str().parse::<usize>().ok())
125                        .map(|l| l.min(num_layers))
126                        .map(DeviceForLoadTensor::Idx)
127                        .unwrap_or(DeviceForLoadTensor::Base)
128                } else {
129                    DeviceForLoadTensor::Base
130                }
131            };
132
133            Ok(Arc::new(closure))
134        }
135    }
136}
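// For example, with `num_layers == 32`, a tensor named `model.layers.10.self_attn.q_proj.weight`
// resolves to `DeviceForLoadTensor::Idx(10)`, while `model.embed_tokens.weight` (no layer index)
// falls back to `DeviceForLoadTensor::Base`. (The tensor names here are illustrative.)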
137
138#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
139#[derive(Clone, Debug, Deserialize, PartialEq)]
140/// The architecture to load the normal model as.
141pub enum NormalLoaderType {
142    #[serde(rename = "mistral")]
143    Mistral,
144    #[serde(rename = "gemma")]
145    Gemma,
146    #[serde(rename = "mixtral")]
147    Mixtral,
148    #[serde(rename = "llama")]
149    Llama,
150    #[serde(rename = "phi2")]
151    Phi2,
152    #[serde(rename = "phi3")]
153    Phi3,
154    #[serde(rename = "qwen2")]
155    Qwen2,
156    #[serde(rename = "gemma2")]
157    Gemma2,
158    #[serde(rename = "starcoder2")]
159    Starcoder2,
160    #[serde(rename = "phi3.5moe")]
161    Phi3_5MoE,
162    #[serde(rename = "deepseekv2")]
163    DeepSeekV2,
164    #[serde(rename = "deepseekv3")]
165    DeepSeekV3,
166    #[serde(rename = "qwen3")]
167    Qwen3,
168    #[serde(rename = "qwen3moe")]
169    Qwen3Moe,
170}
171
172// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
173impl NormalLoaderType {
174    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
175        match name {
176            "MistralForCausalLM" => Ok(Self::Mistral),
177            "MixtralForCausalLM" => Ok(Self::Mixtral),
178            "GemmaForCausalLM" => Ok(Self::Gemma),
179            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
180            "PhiForCausalLM" => Ok(Self::Phi2),
181            "Phi3ForCausalLM" => Ok(Self::Phi3),
182            "LlamaForCausalLM" => Ok(Self::Llama),
183            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
184            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
185            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
186            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
187            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
188            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
189            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
190            other => anyhow::bail!(
191                "Unsupported Hugging Face Transformers `-CausalLM` model class `{other}`. Please raise an issue."
192            ),
193        }
194    }
195}
196
197impl FromStr for NormalLoaderType {
198    type Err = String;
199    fn from_str(s: &str) -> Result<Self, Self::Err> {
200        match s {
201            "mistral" => Ok(Self::Mistral),
202            "gemma" => Ok(Self::Gemma),
203            "mixtral" => Ok(Self::Mixtral),
204            "llama" => Ok(Self::Llama),
205            "phi2" => Ok(Self::Phi2),
206            "phi3" => Ok(Self::Phi3),
207            "qwen2" => Ok(Self::Qwen2),
208            "gemma2" => Ok(Self::Gemma2),
209            "starcoder2" => Ok(Self::Starcoder2),
210            "phi3.5moe" => Ok(Self::Phi3_5MoE),
211            "deepseekv2" => Ok(Self::DeepSeekV2),
212            "deepseekv3" => Ok(Self::DeepSeekV3),
213            "qwen3" => Ok(Self::Qwen3),
214            "qwen3moe" => Ok(Self::Qwen3Moe),
215            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `qwen3moe`.")),
216        }
217    }
218}
219
220impl Display for NormalLoaderType {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        match self {
223            Self::Gemma => write!(f, "gemma"),
224            Self::Gemma2 => write!(f, "gemma2"),
225            Self::Llama => write!(f, "llama"),
226            Self::Mistral => write!(f, "mistral"),
227            Self::Mixtral => write!(f, "mixtral"),
228            Self::Phi2 => write!(f, "phi2"),
229            Self::Phi3 => write!(f, "phi3"),
230            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
231            Self::Qwen2 => write!(f, "qwen2"),
232            Self::Starcoder2 => write!(f, "starcoder2"),
233            Self::DeepSeekV2 => write!(f, "deepseekv2"),
234            Self::DeepSeekV3 => write!(f, "deepseekv3"),
235            Self::Qwen3 => write!(f, "qwen3"),
236            Self::Qwen3Moe => write!(f, "qwen3moe"),
237        }
238    }
239}
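// Minimal illustrative sketch (module name and layout are assumptions, not an established test
// suite): the string names accepted by `FromStr` round-trip through `Display`.
#[cfg(test)]
mod normal_loader_type_tests {
    use super::NormalLoaderType;
    use std::str::FromStr;

    #[test]
    fn parse_and_display_round_trip() {
        for name in ["mistral", "llama", "phi3.5moe", "qwen3", "qwen3moe"] {
            let ty = NormalLoaderType::from_str(name).unwrap();
            assert_eq!(ty.to_string(), name);
        }
    }
}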
240
241macro_rules! bias_if {
242    ($cond:expr, $size:expr) => {
243        if $cond {
244            $size
245        } else {
246            0
247        }
248    };
249}
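// For example, `bias_if!(cfg.attention_bias, size_q)` evaluates to `size_q` when the config
// enables attention biases and to `0` otherwise, so bias elements are only counted when present.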
250
251/// Load a model based on its Hugging Face Transformers `-CausalLM` model class (taken from the `architectures` config field).
252pub struct AutoNormalLoader;
253
254#[derive(Deserialize)]
255struct AutoNormalLoaderConfig {
256    architectures: Vec<String>,
257}
258
259impl AutoNormalLoader {
260    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
261        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
262        if auto_cfg.architectures.len() != 1 {
263            anyhow::bail!("Expected exactly one name in the `architectures` config field.")
264        }
265
266        let name = &auto_cfg.architectures[0];
267
268        let tp = NormalLoaderType::from_causal_lm_name(name)?;
269
270        once_log_info(format!("Automatic loader type determined to be `{tp}`"));
271
272        match tp {
273            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
274            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
275            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
276            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
277            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
278            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
279            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
280            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
281            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
282            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
283            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
284            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
285            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
286            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
287        }
288    }
289}
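// Illustrative sketch (not an established test layout): loader selection depends only on the
// `architectures` field of the model's `config.json`; other fields are ignored at this stage.
#[cfg(test)]
mod auto_loader_tests {
    use super::AutoNormalLoader;

    #[test]
    fn resolves_loader_from_architectures_field() {
        // A single recognized `-CausalLM` class name resolves to a concrete loader.
        let config = r#"{"architectures": ["MistralForCausalLM"]}"#;
        assert!(AutoNormalLoader::get_loader(config).is_ok());

        // Zero (or multiple) entries are rejected.
        let bad = r#"{"architectures": []}"#;
        assert!(AutoNormalLoader::get_loader(bad).is_err());
    }
}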
290
291impl NormalModelLoader for AutoNormalLoader {
292    fn load(
293        &self,
294        config: &str,
295        vb: ShardedVarBuilder,
296        normal_loading_metadata: NormalLoadingMetadata,
297        attention_mechanism: AttentionImplementation,
298    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
299        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
300    }
301    fn load_xlora(
302        &self,
303        config: &str,
304        vb: ShardedVarBuilder,
305        lora_config: &[((String, String), LoraConfig)],
306        xlora_config: Option<XLoraConfig>,
307        xlora_ordering: Ordering,
308        normal_loading_metadata: NormalLoadingMetadata,
309        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
310    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
311        Self::get_loader(config)?.load_xlora(
312            config,
313            vb,
314            lora_config,
315            xlora_config,
316            xlora_ordering,
317            normal_loading_metadata,
318            preload_adapters,
319        )
320    }
321    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
322        Self::get_loader(config)?.get_config_repr(config)
323    }
324    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
325        Self::get_loader(config)?.supports_paged_attention(config)
326    }
327    fn is_gptx(&self, config: &str) -> Result<bool> {
328        Self::get_loader(config)?.is_gptx(config)
329    }
330}
331
332impl IsqModelLoader for AutoNormalLoader {
333    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
334        Self::get_loader(config)?.immediate_isq_predicates(config)
335    }
336    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
337        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
338    }
339    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
340        Self::get_loader(config)?.isq_layer_regexes(config)
341    }
342    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
343        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
344    }
345}
346
347impl DeviceMappedModelLoader for AutoNormalLoader {
348    fn non_mapped_size_in_bytes(
349        &self,
350        config: &str,
351        dtype: DType,
352        weight_pack_factor: usize,
353    ) -> Result<usize> {
354        Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
355    }
356    fn num_layers(&self, config: &str) -> Result<usize> {
357        Self::get_loader(config)?.num_layers(config)
358    }
359    fn layer_sizes_in_bytes(
360        &self,
361        config: &str,
362        dtype: DType,
363        weight_pack_factor: usize,
364    ) -> Result<Vec<usize>> {
365        Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
366    }
367    fn mapped_max_act_size_elems(
368        &self,
369        config: &str,
370        params: &super::AutoDeviceMapParams,
371        prompt_chunksize: usize,
372    ) -> Result<usize> {
373        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
374    }
375    fn non_mapped_max_act_size_elems(
376        &self,
377        _config: &str,
378        _params: &AutoDeviceMapParams,
379    ) -> Result<usize> {
380        Ok(0)
381    }
382    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
383        Self::get_loader(config)?.model_config(config)
384    }
385}
386
387// ======================== Mistral loader
388
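/// [`NormalLoader`] for a Mistral model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html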
389pub struct MistralLoader;
390
391impl NormalModelLoader for MistralLoader {
392    fn load(
393        &self,
394        config: &str,
395        vb: ShardedVarBuilder,
396        normal_loading_metadata: NormalLoadingMetadata,
397        attention_mechanism: AttentionImplementation,
398    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
399        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
400        Ok(Box::new(models::mistral::Model::new(
401            &cfg,
402            vb,
403            self.is_gptx(config)?,
404            normal_loading_metadata,
405            attention_mechanism,
406        )?))
407    }
408    fn load_xlora(
409        &self,
410        config: &str,
411        vb: ShardedVarBuilder,
412        lora_config: &[((String, String), LoraConfig)],
413        xlora_config: Option<XLoraConfig>,
414        xlora_ordering: Ordering,
415        normal_loading_metadata: NormalLoadingMetadata,
416        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
417    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
418        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
419        Ok(Box::new(xlora_models::XLoraMistral::new(
420            &cfg,
421            vb,
422            lora_config,
423            xlora_config,
424            xlora_ordering,
425            self.is_gptx(config)?,
426            normal_loading_metadata,
427            preload_adapters,
428        )?))
429    }
430    fn is_gptx(&self, _: &str) -> Result<bool> {
431        Ok(true)
432    }
433    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
434        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
435        Ok(Box::new(cfg))
436    }
437}
438
439impl IsqModelLoader for MistralLoader {
440    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
441        Ok(vec![
442            Regex::new(r"lm_head\.(weight|bias)$")?,
443            // Attention
444            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
445            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
446            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
447            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
448            // MLP
449            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
450            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
451            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
452        ])
453    }
454    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
455        self.isq_layer_regexes(config)
456    }
457}
458
459impl DeviceMappedModelLoader for MistralLoader {
460    fn mapped_max_act_size_elems(
461        &self,
462        config: &str,
463        params: &AutoDeviceMapParams,
464        prompt_chunksize: usize,
465    ) -> Result<usize> {
466        let AutoDeviceMapParams::Text {
467            max_seq_len: _,
468            max_batch_size,
469        } = params
470        else {
471            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
472        };
473
474        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
475
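        // The estimate below corresponds to an attention-score tensor of shape
        // [max_batch_size, num_attention_heads, prompt_chunksize, prompt_chunksize].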
476        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
477    }
478    fn non_mapped_max_act_size_elems(
479        &self,
480        _config: &str,
481        _params: &AutoDeviceMapParams,
482    ) -> Result<usize> {
483        Ok(0)
484    }
485
486    fn non_mapped_size_in_bytes(
487        &self,
488        config: &str,
489        dtype: DType,
490        weight_pack_factor: usize,
491    ) -> Result<usize> {
492        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
493
494        let elems = {
495            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
496            let lm_head = if !cfg.tie_word_embeddings {
497                cfg.hidden_size * cfg.vocab_size
498            } else {
499                0
500            };
501            let norm = cfg.hidden_size;
502            embed_tokens + lm_head + norm
503        };
504        Ok(elems * dtype.size_in_bytes())
505    }
506
507    fn layer_sizes_in_bytes(
508        &self,
509        config: &str,
510        dtype: DType,
511        weight_pack_factor: usize,
512    ) -> Result<Vec<usize>> {
513        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
514
515        let per_layer_elems = {
516            let input_layernorm = cfg.hidden_size;
517            let post_attention_layernorm = cfg.hidden_size;
518
519            let size_in = cfg.hidden_size;
520            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
521            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
522            let q_proj = size_in * size_q / weight_pack_factor;
523            let k_proj = size_in * size_kv / weight_pack_factor;
524            let v_proj = size_in * size_kv / weight_pack_factor;
525            let o_proj = size_q * size_in / weight_pack_factor;
526
527            let h_size = cfg.hidden_size;
528            let i_size = cfg.intermediate_size;
529            let gate_proj = h_size * i_size / weight_pack_factor;
530            let up_proj = h_size * i_size / weight_pack_factor;
531            let down_proj = i_size * h_size / weight_pack_factor;
532
533            input_layernorm
534                + post_attention_layernorm
535                + q_proj
536                + k_proj
537                + v_proj
538                + o_proj
539                + gate_proj
540                + up_proj
541                + down_proj
542        };
543        Ok(vec![
544            per_layer_elems * dtype.size_in_bytes();
545            cfg.num_hidden_layers
546        ])
547    }
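    // Worked example (illustrative): for a Mistral-7B-style config (hidden_size 4096,
    // 32 attention heads, 8 KV heads, intermediate_size 14336, weight_pack_factor 1),
    // per_layer_elems is about 218M elements, i.e. roughly 436 MB per layer in BF16;
    // across 32 layers this accounts for ~7B parameters, matching the model size.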
548
549    fn num_layers(&self, config: &str) -> Result<usize> {
550        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
551        Ok(cfg.num_hidden_layers)
552    }
553
554    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
555        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
556
557        let cfg = ModelConfigMetadata {
558            max_seq_len: cfg.max_position_embeddings,
559            num_layers: cfg.num_hidden_layers,
560            hidden_size: cfg.hidden_size,
561            num_kv_heads: cfg.num_key_value_heads,
562            num_attn_heads: cfg.num_attention_heads,
563            sliding_window: cfg.sliding_window,
564            k_head_dim: cfg.head_dim(),
565            v_head_dim: cfg.head_dim(),
566        };
567
568        Ok(Box::new(cfg))
569    }
570}
571
572// ======================== Gemma loader
573
574/// [`NormalLoader`] for a Gemma model.
575///
576/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
577pub struct GemmaLoader;
578
579impl NormalModelLoader for GemmaLoader {
580    fn load(
581        &self,
582        config: &str,
583        vb: ShardedVarBuilder,
584        normal_loading_metadata: NormalLoadingMetadata,
585        attention_mechanism: AttentionImplementation,
586    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
587        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
588
589        Ok(Box::new(models::gemma::Model::new(
590            &cfg,
591            vb,
592            self.is_gptx(config)?,
593            normal_loading_metadata,
594            attention_mechanism,
595        )?))
596    }
597    fn load_xlora(
598        &self,
599        config: &str,
600        vb: ShardedVarBuilder,
601        lora_config: &[((String, String), LoraConfig)],
602        xlora_config: Option<XLoraConfig>,
603        xlora_ordering: Ordering,
604        normal_loading_metadata: NormalLoadingMetadata,
605        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
606    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
607        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
608
609        Ok(Box::new(xlora_models::XLoraGemma::new(
610            &cfg,
611            vb,
612            lora_config,
613            xlora_config,
614            xlora_ordering,
615            self.is_gptx(config)?,
616            normal_loading_metadata,
617            preload_adapters,
618        )?))
619    }
620    fn is_gptx(&self, _: &str) -> Result<bool> {
621        Ok(true)
622    }
623    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
624        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
625        Ok(Box::new(cfg))
626    }
627}
628
629impl IsqModelLoader for GemmaLoader {
630    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
631        Ok(vec![
632            Regex::new(r"lm_head\.(weight|bias)$")?,
633            // Attention
634            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
635            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
636            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
637            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
638            // MLP
639            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
640            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
641            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
642        ])
643    }
644    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
645        self.isq_layer_regexes(config)
646    }
647}
648
649impl DeviceMappedModelLoader for GemmaLoader {
650    fn mapped_max_act_size_elems(
651        &self,
652        config: &str,
653        params: &AutoDeviceMapParams,
654        prompt_chunksize: usize,
655    ) -> Result<usize> {
656        let AutoDeviceMapParams::Text {
657            max_seq_len: _,
658            max_batch_size,
659        } = params
660        else {
661            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
662        };
663
664        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
665
666        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
667    }
668    fn non_mapped_max_act_size_elems(
669        &self,
670        _config: &str,
671        _params: &AutoDeviceMapParams,
672    ) -> Result<usize> {
673        Ok(0)
674    }
675
676    fn non_mapped_size_in_bytes(
677        &self,
678        config: &str,
679        dtype: DType,
680        weight_pack_factor: usize,
681    ) -> Result<usize> {
682        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
683
684        let elems = {
685            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
686            let lm_head = if !cfg.tie_word_embeddings {
687                cfg.hidden_size * cfg.vocab_size
688            } else {
689                0
690            };
691            let norm = cfg.hidden_size;
692            embed_tokens + lm_head + norm
693        };
694        Ok(elems * dtype.size_in_bytes())
695    }
696
697    fn layer_sizes_in_bytes(
698        &self,
699        config: &str,
700        dtype: DType,
701        weight_pack_factor: usize,
702    ) -> Result<Vec<usize>> {
703        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
704
705        let per_layer_elems = {
706            let input_layernorm = cfg.hidden_size;
707            let post_attention_layernorm = cfg.hidden_size;
708
709            let size_in = cfg.hidden_size;
710            let size_q = cfg.head_dim * cfg.num_attention_heads;
711            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
712            let q_proj =
713                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
714            let k_proj =
715                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
716            let v_proj =
717                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
718            let o_proj =
719                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
720
721            let h_size = cfg.hidden_size;
722            let i_size = cfg.intermediate_size;
723            let gate_proj = h_size * i_size / weight_pack_factor;
724            let up_proj = h_size * i_size / weight_pack_factor;
725            let down_proj = i_size * h_size / weight_pack_factor;
726
727            input_layernorm
728                + post_attention_layernorm
729                + q_proj
730                + k_proj
731                + v_proj
732                + o_proj
733                + gate_proj
734                + up_proj
735                + down_proj
736        };
737        Ok(vec![
738            per_layer_elems * dtype.size_in_bytes();
739            cfg.num_hidden_layers
740        ])
741    }
742
743    fn num_layers(&self, config: &str) -> Result<usize> {
744        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
745        Ok(cfg.num_hidden_layers)
746    }
747
748    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
749        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
750
751        let cfg = ModelConfigMetadata {
752            max_seq_len: cfg.max_position_embeddings,
753            num_layers: cfg.num_hidden_layers,
754            hidden_size: cfg.hidden_size,
755            num_kv_heads: cfg.num_key_value_heads,
756            num_attn_heads: cfg.num_attention_heads,
757            sliding_window: None,
758            k_head_dim: cfg.head_dim,
759            v_head_dim: cfg.head_dim,
760        };
761
762        Ok(Box::new(cfg))
763    }
764}
765
766// ======================== Llama loader
767
768/// [`NormalLoader`] for a Llama model.
769///
770/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
771pub struct LlamaLoader;
772
773impl NormalModelLoader for LlamaLoader {
774    fn load(
775        &self,
776        config: &str,
777        vb: ShardedVarBuilder,
778        normal_loading_metadata: NormalLoadingMetadata,
779        attention_mechanism: AttentionImplementation,
780    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
781        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
782
783        Ok(Box::new(models::llama::Llama::new(
784            &cfg,
785            vb,
786            self.is_gptx(config)?,
787            normal_loading_metadata,
788            attention_mechanism,
789        )?))
790    }
791    fn load_xlora(
792        &self,
793        config: &str,
794        vb: ShardedVarBuilder,
795        lora_config: &[((String, String), LoraConfig)],
796        xlora_config: Option<XLoraConfig>,
797        xlora_ordering: Ordering,
798        normal_loading_metadata: NormalLoadingMetadata,
799        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
800    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
801        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
802
803        Ok(Box::new(xlora_models::XLoraLlama::new(
804            &cfg,
805            vb,
806            lora_config,
807            xlora_config,
808            xlora_ordering,
809            self.is_gptx(config)?,
810            normal_loading_metadata,
811            preload_adapters,
812        )?))
813    }
814    fn is_gptx(&self, _: &str) -> Result<bool> {
815        Ok(true)
816    }
817    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
818        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
819        Ok(Box::new(cfg))
820    }
821}
822
823impl IsqModelLoader for LlamaLoader {
824    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
825        Ok(vec![
826            Regex::new(r"lm_head\.(weight|bias)$")?,
827            // Attention
828            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
829            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
830            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
831            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
832            // MLP
833            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
834            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
835            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
836        ])
837    }
838    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
839        self.isq_layer_regexes(config)
840    }
841}
842
843impl DeviceMappedModelLoader for LlamaLoader {
844    fn mapped_max_act_size_elems(
845        &self,
846        config: &str,
847        params: &AutoDeviceMapParams,
848        prompt_chunksize: usize,
849    ) -> Result<usize> {
850        let AutoDeviceMapParams::Text {
851            max_seq_len: _,
852            max_batch_size,
853        } = params
854        else {
855            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
856        };
857
858        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
859
860        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
861    }
862    fn non_mapped_max_act_size_elems(
863        &self,
864        _config: &str,
865        _params: &AutoDeviceMapParams,
866    ) -> Result<usize> {
867        Ok(0)
868    }
869
870    fn non_mapped_size_in_bytes(
871        &self,
872        config: &str,
873        dtype: DType,
874        weight_pack_factor: usize,
875    ) -> Result<usize> {
876        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
877
878        let elems = {
879            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
880            let lm_head = if !cfg.tie_word_embeddings {
881                cfg.hidden_size * cfg.vocab_size
882            } else {
883                0
884            };
885            let norm = cfg.hidden_size;
886            embed_tokens + lm_head + norm
887        };
888        Ok(elems * dtype.size_in_bytes())
889    }
890
891    fn layer_sizes_in_bytes(
892        &self,
893        config: &str,
894        dtype: DType,
895        weight_pack_factor: usize,
896    ) -> Result<Vec<usize>> {
897        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
898
899        let per_layer_elems = {
900            let input_layernorm = cfg.hidden_size;
901            let post_attention_layernorm = cfg.hidden_size;
902
903            let size_in = cfg.hidden_size;
904            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
905            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
906            let q_proj = size_in * size_q / weight_pack_factor;
907            let k_proj = size_in * size_kv / weight_pack_factor;
908            let v_proj = size_in * size_kv / weight_pack_factor;
909            let o_proj = size_q * size_in / weight_pack_factor;
910
911            let h_size = cfg.hidden_size;
912            let i_size = cfg.intermediate_size;
913            let gate_proj = h_size * i_size / weight_pack_factor;
914            let up_proj = h_size * i_size / weight_pack_factor;
915            let down_proj = i_size * h_size / weight_pack_factor;
916
917            input_layernorm
918                + post_attention_layernorm
919                + q_proj
920                + k_proj
921                + v_proj
922                + o_proj
923                + gate_proj
924                + up_proj
925                + down_proj
926        };
927        Ok(vec![
928            per_layer_elems * dtype.size_in_bytes();
929            cfg.num_hidden_layers
930        ])
931    }
932
933    fn num_layers(&self, config: &str) -> Result<usize> {
934        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
935
936        Ok(cfg.num_hidden_layers)
937    }
938    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
939        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
940
941        let cfg = ModelConfigMetadata {
942            max_seq_len: cfg.max_position_embeddings,
943            num_layers: cfg.num_hidden_layers,
944            hidden_size: cfg.hidden_size,
945            num_kv_heads: cfg.num_key_value_heads,
946            num_attn_heads: cfg.num_attention_heads,
947            sliding_window: None,
948            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
949            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
950        };
951
952        Ok(Box::new(cfg))
953    }
954}
955
956// ======================== Mixtral loader
957
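/// [`NormalLoader`] for a Mixtral model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html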
958pub struct MixtralLoader;
959
960impl NormalModelLoader for MixtralLoader {
961    fn load(
962        &self,
963        config: &str,
964        vb: ShardedVarBuilder,
965        normal_loading_metadata: NormalLoadingMetadata,
966        attention_mechanism: AttentionImplementation,
967    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
968        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
969
970        Ok(Box::new(models::mixtral::Model::new(
971            &cfg,
972            vb,
973            self.is_gptx(config)?,
974            normal_loading_metadata,
975            attention_mechanism,
976        )?))
977    }
978    fn load_xlora(
979        &self,
980        config: &str,
981        vb: ShardedVarBuilder,
982        lora_config: &[((String, String), LoraConfig)],
983        xlora_config: Option<XLoraConfig>,
984        xlora_ordering: Ordering,
985        normal_loading_metadata: NormalLoadingMetadata,
986        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
987    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
988        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
989
990        Ok(Box::new(xlora_models::XLoraMixtral::new(
991            &cfg,
992            vb,
993            lora_config,
994            xlora_config,
995            xlora_ordering,
996            self.is_gptx(config)?,
997            normal_loading_metadata,
998            preload_adapters,
999        )?))
1000    }
1001    fn is_gptx(&self, _: &str) -> Result<bool> {
1002        Ok(true)
1003    }
1004    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1005        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1006
1007        Ok(Box::new(cfg))
1008    }
1009}
1010
1011impl IsqModelLoader for MixtralLoader {
1012    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1013        Ok(vec![
1014            Regex::new(r"lm_head\.(weight|bias)$")?,
1015            // Attention
1016            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1017            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1018            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1019            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1020            // Experts
1021            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1022            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1023            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1024            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1025        ])
1026    }
1027    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1028        self.isq_layer_regexes(config)
1029    }
1030}
1031
1032impl DeviceMappedModelLoader for MixtralLoader {
1033    fn mapped_max_act_size_elems(
1034        &self,
1035        config: &str,
1036        params: &AutoDeviceMapParams,
1037        prompt_chunksize: usize,
1038    ) -> Result<usize> {
1039        let AutoDeviceMapParams::Text {
1040            max_seq_len: _,
1041            max_batch_size,
1042        } = params
1043        else {
1044            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1045        };
1046
1047        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1048
1049        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1050    }
1051    fn non_mapped_max_act_size_elems(
1052        &self,
1053        _config: &str,
1054        _params: &AutoDeviceMapParams,
1055    ) -> Result<usize> {
1056        Ok(0)
1057    }
1058
1059    fn non_mapped_size_in_bytes(
1060        &self,
1061        config: &str,
1062        dtype: DType,
1063        weight_pack_factor: usize,
1064    ) -> Result<usize> {
1065        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1066
1067        let elems = {
1068            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1069            let lm_head = if !cfg.tie_word_embeddings {
1070                cfg.hidden_size * cfg.vocab_size
1071            } else {
1072                0
1073            };
1074            let norm = cfg.hidden_size;
1075            embed_tokens + lm_head + norm
1076        };
1077        Ok(elems * dtype.size_in_bytes())
1078    }
1079
1080    fn layer_sizes_in_bytes(
1081        &self,
1082        config: &str,
1083        dtype: DType,
1084        weight_pack_factor: usize,
1085    ) -> Result<Vec<usize>> {
1086        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1087
1088        let per_layer_elems = {
1089            let input_layernorm = cfg.hidden_size;
1090            let post_attention_layernorm = cfg.hidden_size;
1091
1092            let size_in = cfg.hidden_size;
1093            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1094            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1095            let q_proj = size_in * size_q / weight_pack_factor;
1096            let k_proj = size_in * size_kv / weight_pack_factor;
1097            let v_proj = size_in * size_kv / weight_pack_factor;
1098            let o_proj = size_q * size_in / weight_pack_factor;
1099
1100            let moe_block = {
1101                let gate = cfg.hidden_size * cfg.num_local_experts;
1102                // Assume the weight pack factor applies to each expert's quantized weight matrices
1103                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1104                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1105                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1106                gate + cfg.num_local_experts * w1
1107                    + cfg.num_local_experts * w2
1108                    + cfg.num_local_experts * w3
1109            };
1110
1111            input_layernorm
1112                + post_attention_layernorm
1113                + q_proj
1114                + k_proj
1115                + v_proj
1116                + o_proj
1117                + moe_block
1118        };
1119        Ok(vec![
1120            per_layer_elems * dtype.size_in_bytes();
1121            cfg.num_hidden_layers
1122        ])
1123    }
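    // Worked example (illustrative): for a Mixtral-8x7B-style config (hidden_size 4096,
    // intermediate_size 14336, 8 local experts, 32 attention heads, 8 KV heads,
    // weight_pack_factor 1), the MoE block dominates at ~1.41B elements per layer versus
    // ~42M for attention, giving ~1.45B per layer and ~46.4B elements across 32 layers.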
1124
1125    fn num_layers(&self, config: &str) -> Result<usize> {
1126        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1127
1128        Ok(cfg.num_hidden_layers)
1129    }
1130
1131    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1132        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1133
1134        let cfg = ModelConfigMetadata {
1135            max_seq_len: cfg.max_position_embeddings,
1136            num_layers: cfg.num_hidden_layers,
1137            hidden_size: cfg.hidden_size,
1138            num_kv_heads: cfg.num_key_value_heads,
1139            num_attn_heads: cfg.num_attention_heads,
1140            sliding_window: cfg.sliding_window,
1141            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1142            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1143        };
1144
1145        Ok(Box::new(cfg))
1146    }
1147}
1148
1149// ======================== Phi2 loader
1150
1151/// [`NormalLoader`] for a Phi 2 model.
1152///
1153/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1154pub struct Phi2Loader;
1155
1156impl NormalModelLoader for Phi2Loader {
1157    fn load(
1158        &self,
1159        config: &str,
1160        vb: ShardedVarBuilder,
1161        normal_loading_metadata: NormalLoadingMetadata,
1162        attention_mechanism: AttentionImplementation,
1163    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1164        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1165
1166        Ok(Box::new(models::phi2::Model::new(
1167            &cfg,
1168            vb,
1169            self.is_gptx(config)?,
1170            normal_loading_metadata,
1171            attention_mechanism,
1172        )?))
1173    }
1174    fn load_xlora(
1175        &self,
1176        config: &str,
1177        vb: ShardedVarBuilder,
1178        lora_config: &[((String, String), LoraConfig)],
1179        xlora_config: Option<XLoraConfig>,
1180        xlora_ordering: Ordering,
1181        normal_loading_metadata: NormalLoadingMetadata,
1182        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1183    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1184        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1185
1186        Ok(Box::new(xlora_models::XLoraPhi2::new(
1187            &cfg,
1188            vb,
1189            lora_config,
1190            xlora_config,
1191            xlora_ordering,
1192            self.is_gptx(config)?,
1193            normal_loading_metadata,
1194            preload_adapters,
1195        )?))
1196    }
1197    fn is_gptx(&self, _: &str) -> Result<bool> {
1198        Ok(true)
1199    }
1200    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1201        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1202
1203        Ok(Box::new(cfg))
1204    }
1205}
1206
1207impl IsqModelLoader for Phi2Loader {
1208    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1209        Ok(vec![
1210            Regex::new(r"lm_head\.(weight|bias)$")?,
1211            // Attention
1212            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1213            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1214            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1215            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1216            // MLP
1217            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1218            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1219        ])
1220    }
1221    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1222        self.isq_layer_regexes(config)
1223    }
1224}
1225
1226impl DeviceMappedModelLoader for Phi2Loader {
1227    fn mapped_max_act_size_elems(
1228        &self,
1229        config: &str,
1230        params: &AutoDeviceMapParams,
1231        prompt_chunksize: usize,
1232    ) -> Result<usize> {
1233        let AutoDeviceMapParams::Text {
1234            max_seq_len: _,
1235            max_batch_size,
1236        } = params
1237        else {
1238            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1239        };
1240
1241        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1242
1243        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1244    }
1245    fn non_mapped_max_act_size_elems(
1246        &self,
1247        _config: &str,
1248        _params: &AutoDeviceMapParams,
1249    ) -> Result<usize> {
1250        Ok(0)
1251    }
1252
1253    fn non_mapped_size_in_bytes(
1254        &self,
1255        config: &str,
1256        dtype: DType,
1257        weight_pack_factor: usize,
1258    ) -> Result<usize> {
1259        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1260
1261        let elems = {
1262            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1263            let lm_head = if !cfg.tie_word_embeddings {
1264                cfg.hidden_size * cfg.vocab_size
1265            } else {
1266                0
1267            };
1268            let norm = cfg.hidden_size;
1269            embed_tokens + lm_head + norm
1270        };
1271        Ok(elems * dtype.size_in_bytes())
1272    }
1273
1274    fn layer_sizes_in_bytes(
1275        &self,
1276        config: &str,
1277        dtype: DType,
1278        weight_pack_factor: usize,
1279    ) -> Result<Vec<usize>> {
1280        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1281
1282        let per_layer_elems = {
1283            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1284
1285            let size_in = cfg.hidden_size;
1286            let size_q = cfg.head_dim() * cfg.num_attention_heads;
1287            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1288            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1289            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1290            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1291            let o_proj = size_q * size_in / weight_pack_factor + size_in;
1292            let (q_norm, k_norm) = if cfg.qk_layernorm {
1293                (cfg.head_dim(), cfg.head_dim())
1294            } else {
1295                (0, 0)
1296            };
1297
1298            let h_size = cfg.hidden_size;
1299            let i_size = cfg.intermediate_size;
1300            let fc1 = h_size * i_size / weight_pack_factor;
1301            let fc2 = h_size * i_size / weight_pack_factor;
1302
1303            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1304        };
1305        Ok(vec![
1306            per_layer_elems * dtype.size_in_bytes();
1307            cfg.num_hidden_layers
1308        ])
1309    }
1310
1311    fn num_layers(&self, config: &str) -> Result<usize> {
1312        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1313
1314        Ok(cfg.num_hidden_layers)
1315    }
1316
1317    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1318        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1319
1320        let cfg = ModelConfigMetadata {
1321            max_seq_len: cfg.max_position_embeddings,
1322            num_layers: cfg.num_hidden_layers,
1323            hidden_size: cfg.hidden_size,
1324            num_kv_heads: cfg.num_key_value_heads(),
1325            num_attn_heads: cfg.num_attention_heads,
1326            sliding_window: None,
1327            k_head_dim: cfg.head_dim(),
1328            v_head_dim: cfg.head_dim(),
1329        };
1330
1331        Ok(Box::new(cfg))
1332    }
1333}
1334
1335// ======================== Phi3 loader
1336
1337/// [`NormalLoader`] for a Phi 3 model.
1338///
1339/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1340pub struct Phi3Loader;
1341
1342impl NormalModelLoader for Phi3Loader {
1343    fn load(
1344        &self,
1345        config: &str,
1346        vb: ShardedVarBuilder,
1347        normal_loading_metadata: NormalLoadingMetadata,
1348        attention_mechanism: AttentionImplementation,
1349    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1350        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1351
1352        Ok(Box::new(models::phi3::Model::new(
1353            &cfg,
1354            vb,
1355            self.is_gptx(config)?,
1356            normal_loading_metadata,
1357            attention_mechanism,
1358        )?))
1359    }
1360    fn load_xlora(
1361        &self,
1362        config: &str,
1363        vb: ShardedVarBuilder,
1364        lora_config: &[((String, String), LoraConfig)],
1365        xlora_config: Option<XLoraConfig>,
1366        xlora_ordering: Ordering,
1367        normal_loading_metadata: NormalLoadingMetadata,
1368        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1369    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1370        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1371
1372        Ok(Box::new(xlora_models::XLoraPhi3::new(
1373            &cfg,
1374            vb,
1375            lora_config,
1376            xlora_config,
1377            xlora_ordering,
1378            self.is_gptx(config)?,
1379            normal_loading_metadata,
1380            preload_adapters,
1381        )?))
1382    }
1383    fn is_gptx(&self, _: &str) -> Result<bool> {
1384        Ok(true)
1385    }
1386    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1387        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1388
1389        Ok(Box::new(cfg))
1390    }
1391}
1392
1393impl IsqModelLoader for Phi3Loader {
1394    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1395        Ok(vec![
1396            Regex::new(r"lm_head\.(weight|bias)$")?,
1397            // Attention
1398            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1399            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1400            // MLP
1401            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1402            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1403            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1404        ])
1405    }
1406    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1407        self.isq_layer_regexes(config)
1408    }
1409}
1410
1411impl DeviceMappedModelLoader for Phi3Loader {
1412    fn mapped_max_act_size_elems(
1413        &self,
1414        config: &str,
1415        params: &AutoDeviceMapParams,
1416        prompt_chunksize: usize,
1417    ) -> Result<usize> {
1418        let AutoDeviceMapParams::Text {
1419            max_seq_len: _,
1420            max_batch_size,
1421        } = params
1422        else {
1423            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1424        };
1425
1426        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1427
1428        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1429    }
1430    fn non_mapped_max_act_size_elems(
1431        &self,
1432        _config: &str,
1433        _params: &AutoDeviceMapParams,
1434    ) -> Result<usize> {
1435        Ok(0)
1436    }
1437
1438    fn non_mapped_size_in_bytes(
1439        &self,
1440        config: &str,
1441        dtype: DType,
1442        weight_pack_factor: usize,
1443    ) -> Result<usize> {
1444        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1445
1446        let elems = {
1447            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1448            let lm_head = if !cfg.tie_word_embeddings {
1449                cfg.hidden_size * cfg.vocab_size
1450            } else {
1451                0
1452            };
1453            let norm = cfg.hidden_size;
1454            embed_tokens + lm_head + norm
1455        };
1456        Ok(elems * dtype.size_in_bytes())
1457    }
1458
1459    fn layer_sizes_in_bytes(
1460        &self,
1461        config: &str,
1462        dtype: DType,
1463        weight_pack_factor: usize,
1464    ) -> Result<Vec<usize>> {
1465        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1466
1467        let per_layer_elems = {
1468            let input_layernorm = cfg.hidden_size;
1469            let post_attention_layernorm = cfg.hidden_size;
1470
1471            let size_in = cfg.hidden_size;
1472            let head_dim = cfg.head_dim();
1473            let op_size =
1474                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1475            let qkv_proj = size_in * op_size / weight_pack_factor;
1476            let o_proj =
1477                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1478
1479            let h_size = cfg.hidden_size;
1480            let i_size = cfg.intermediate_size;
1481            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1482            let down_proj = h_size * i_size / weight_pack_factor;
1483
1484            input_layernorm
1485                + post_attention_layernorm
1486                + qkv_proj
1487                + o_proj
1488                + gate_up_proj
1489                + down_proj
1490        };
1491        Ok(vec![
1492            per_layer_elems * dtype.size_in_bytes();
1493            cfg.num_hidden_layers
1494        ])
1495    }
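    // Note: Phi-3 fuses the attention projections into a single `qkv_proj` (Q plus both K and V
    // heads) and the MLP gate/up projections into `gate_up_proj`, which is why `op_size` and the
    // `2 * i_size` factor above differ from the per-projection loaders earlier in this file.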
1496
1497    fn num_layers(&self, config: &str) -> Result<usize> {
1498        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1499
1500        Ok(cfg.num_hidden_layers)
1501    }
1502
1503    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1504        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1505
1506        let cfg = ModelConfigMetadata {
1507            max_seq_len: cfg.max_position_embeddings,
1508            num_layers: cfg.num_hidden_layers,
1509            hidden_size: cfg.hidden_size,
1510            num_kv_heads: cfg.num_key_value_heads,
1511            num_attn_heads: cfg.num_attention_heads,
1512            sliding_window: cfg.sliding_window,
1513            k_head_dim: cfg.head_dim(),
1514            v_head_dim: cfg.head_dim(),
1515        };
1516
1517        Ok(Box::new(cfg))
1518    }
1519}
1520
1521// ======================== Qwen2 loader
1522
1523/// [`NormalLoader`] for a Qwen 2 model.
1524///
1525/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
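///
/// A minimal usage sketch (illustrative only, not from the crate docs); `CONFIG_JSON`
/// is a placeholder for the model's `config.json` contents:
///
/// ```ignore
/// // CONFIG_JSON: stand-in for the model's config.json string.
/// let loader = Qwen2Loader;
/// let n_layers = loader.num_layers(CONFIG_JSON)?;
/// let per_layer = loader.layer_sizes_in_bytes(CONFIG_JSON, candle_core::DType::F16, 1)?;
/// assert_eq!(per_layer.len(), n_layers);
/// ```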
1526pub struct Qwen2Loader;
1527
1528impl NormalModelLoader for Qwen2Loader {
1529    fn load(
1530        &self,
1531        config: &str,
1532        vb: ShardedVarBuilder,
1533        normal_loading_metadata: NormalLoadingMetadata,
1534        attention_mechanism: AttentionImplementation,
1535    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1536        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1537
1538        Ok(Box::new(models::qwen2::Model::new(
1539            &cfg,
1540            vb,
1541            self.is_gptx(config)?,
1542            normal_loading_metadata,
1543            attention_mechanism,
1544        )?))
1545    }
1546    fn load_xlora(
1547        &self,
1548        _config: &str,
1549        _vb: ShardedVarBuilder,
1550        _lora_config: &[((String, String), LoraConfig)],
1551        _xlora_config: Option<XLoraConfig>,
1552        _xlora_ordering: Ordering,
1553        _normal_loading_metadata: NormalLoadingMetadata,
1554        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1555    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1556        todo!()
1557    }
1558    fn is_gptx(&self, _: &str) -> Result<bool> {
1559        Ok(true)
1560    }
1561    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1562        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1563
1564        Ok(Box::new(cfg))
1565    }
1566}
1567
1568impl IsqModelLoader for Qwen2Loader {
1569    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1570        Ok(vec![
1571            Regex::new(r"lm_head\.(weight|bias)$")?,
1572            // Attention
1573            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1574            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1575            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1576            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1577            // MLP
1578            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1579            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1580            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1581        ])
1582    }
1583    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1584        self.isq_layer_regexes(config)
1585    }
1586}
1587
1588impl DeviceMappedModelLoader for Qwen2Loader {
1589    fn mapped_max_act_size_elems(
1590        &self,
1591        config: &str,
1592        params: &AutoDeviceMapParams,
1593        prompt_chunksize: usize,
1594    ) -> Result<usize> {
1595        let AutoDeviceMapParams::Text {
1596            max_seq_len: _,
1597            max_batch_size,
1598        } = params
1599        else {
1600            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1601        };
1602
1603        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1604
1605        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1606    }
1607    fn non_mapped_max_act_size_elems(
1608        &self,
1609        _config: &str,
1610        _params: &AutoDeviceMapParams,
1611    ) -> Result<usize> {
1612        Ok(0)
1613    }
1614
1615    fn non_mapped_size_in_bytes(
1616        &self,
1617        config: &str,
1618        dtype: DType,
1619        weight_pack_factor: usize,
1620    ) -> Result<usize> {
1621        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1622
1623        let elems = {
1624            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1625            let lm_head = if !cfg.tie_word_embeddings {
1626                cfg.hidden_size * cfg.vocab_size
1627            } else {
1628                0
1629            };
1630            let norm = cfg.hidden_size;
1631            embed_tokens + lm_head + norm
1632        };
1633        Ok(elems * dtype.size_in_bytes())
1634    }
1635
1636    fn layer_sizes_in_bytes(
1637        &self,
1638        config: &str,
1639        dtype: DType,
1640        weight_pack_factor: usize,
1641    ) -> Result<Vec<usize>> {
1642        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1643
1644        let per_layer_elems = {
1645            let input_layernorm = cfg.hidden_size;
1646            let post_attention_layernorm = cfg.hidden_size;
1647
1648            let size_in = cfg.hidden_size;
1649            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1650            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
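            // Qwen2 uses biases on the q/k/v projections (the `+ size_q` / `+ size_kv`
            // terms below) but not on o_proj.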
1651            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1652            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1653            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1654            let o_proj = size_q * size_in / weight_pack_factor;
1655
1656            let h_size = cfg.hidden_size;
1657            let i_size = cfg.intermediate_size;
1658            let gate_proj = h_size * i_size / weight_pack_factor;
1659            let up_proj = h_size * i_size / weight_pack_factor;
1660            let down_proj = i_size * h_size / weight_pack_factor;
1661
1662            input_layernorm
1663                + post_attention_layernorm
1664                + q_proj
1665                + k_proj
1666                + v_proj
1667                + o_proj
1668                + gate_proj
1669                + up_proj
1670                + down_proj
1671        };
1672        Ok(vec![
1673            per_layer_elems * dtype.size_in_bytes();
1674            cfg.num_hidden_layers
1675        ])
1676    }
1677
1678    fn num_layers(&self, config: &str) -> Result<usize> {
1679        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1680
1681        Ok(cfg.num_hidden_layers)
1682    }
1683
1684    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1685        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1686
1687        let cfg = ModelConfigMetadata {
1688            max_seq_len: cfg.max_position_embeddings,
1689            num_layers: cfg.num_hidden_layers,
1690            hidden_size: cfg.hidden_size,
1691            num_kv_heads: cfg.num_key_value_heads,
1692            num_attn_heads: cfg.num_attention_heads,
1693            sliding_window: cfg.sliding_window,
1694            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1695            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1696        };
1697
1698        Ok(Box::new(cfg))
1699    }
1700}
1701
1702// ======================== Gemma2 loader
1703
1704/// [`NormalLoader`] for a Gemma2 model.
1705///
1706/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1707pub struct Gemma2Loader;
1708
1709impl NormalModelLoader for Gemma2Loader {
1710    fn load(
1711        &self,
1712        config: &str,
1713        vb: ShardedVarBuilder,
1714        normal_loading_metadata: NormalLoadingMetadata,
1715        attention_mechanism: AttentionImplementation,
1716    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1717        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1718
1719        Ok(Box::new(models::gemma2::Model::new(
1720            &cfg,
1721            vb,
1722            self.is_gptx(config)?,
1723            normal_loading_metadata,
1724            attention_mechanism,
1725        )?))
1726    }
1727    fn load_xlora(
1728        &self,
1729        config: &str,
1730        vb: ShardedVarBuilder,
1731        lora_config: &[((String, String), LoraConfig)],
1732        xlora_config: Option<XLoraConfig>,
1733        xlora_ordering: Ordering,
1734        normal_loading_metadata: NormalLoadingMetadata,
1735        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1736    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1737        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1738
1739        Ok(Box::new(xlora_models::XLoraGemma2::new(
1740            &cfg,
1741            vb,
1742            lora_config,
1743            xlora_config,
1744            xlora_ordering,
1745            self.is_gptx(config)?,
1746            normal_loading_metadata,
1747            preload_adapters,
1748        )?))
1749    }
1750    fn is_gptx(&self, _: &str) -> Result<bool> {
1751        Ok(true)
1752    }
1753    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1754        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1755
1756        Ok(Box::new(cfg))
1757    }
1758}
1759
1760impl IsqModelLoader for Gemma2Loader {
1761    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1762        Ok(vec![
1763            Regex::new(r"lm_head\.(weight|bias)$")?,
1764            // Attention
1765            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1766            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1767            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1768            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1769            // MLP
1770            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1771            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1772            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1773        ])
1774    }
1775    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1776        self.isq_layer_regexes(config)
1777    }
1778}
1779
1780impl DeviceMappedModelLoader for Gemma2Loader {
1781    fn mapped_max_act_size_elems(
1782        &self,
1783        config: &str,
1784        params: &AutoDeviceMapParams,
1785        prompt_chunksize: usize,
1786    ) -> Result<usize> {
1787        let AutoDeviceMapParams::Text {
1788            max_seq_len: _,
1789            max_batch_size,
1790        } = params
1791        else {
1792            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1793        };
1794
1795        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1796
1797        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1798    }
1799    fn non_mapped_max_act_size_elems(
1800        &self,
1801        _config: &str,
1802        _params: &AutoDeviceMapParams,
1803    ) -> Result<usize> {
1804        Ok(0)
1805    }
1806
1807    fn non_mapped_size_in_bytes(
1808        &self,
1809        config: &str,
1810        dtype: DType,
1811        weight_pack_factor: usize,
1812    ) -> Result<usize> {
1813        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1814
1815        let elems = {
1816            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1817            let lm_head = if !cfg.tie_word_embeddings {
1818                cfg.hidden_size * cfg.vocab_size
1819            } else {
1820                0
1821            };
1822            let norm = cfg.hidden_size;
1823            embed_tokens + lm_head + norm
1824        };
1825        Ok(elems * dtype.size_in_bytes())
1826    }
1827
1828    fn layer_sizes_in_bytes(
1829        &self,
1830        config: &str,
1831        dtype: DType,
1832        weight_pack_factor: usize,
1833    ) -> Result<Vec<usize>> {
1834        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1835
1836        let per_layer_elems = {
1837            let input_layernorm = cfg.hidden_size;
1838            let post_attention_layernorm = cfg.hidden_size;
1839
1840            let size_in = cfg.hidden_size;
1841            let size_q = cfg.head_dim * cfg.num_attention_heads;
1842            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
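            // `bias_if!(flag, n)` counts `n` bias elements only when the config enables
            // attention biases.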
1843            let q_proj =
1844                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1845            let k_proj =
1846                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1847            let v_proj =
1848                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1849            let o_proj =
1850                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1851
1852            let h_size = cfg.hidden_size;
1853            let i_size = cfg.intermediate_size;
1854            let gate_proj = h_size * i_size / weight_pack_factor;
1855            let up_proj = h_size * i_size / weight_pack_factor;
1856            let down_proj = i_size * h_size / weight_pack_factor;
1857
1858            input_layernorm
1859                + post_attention_layernorm
1860                + q_proj
1861                + k_proj
1862                + v_proj
1863                + o_proj
1864                + gate_proj
1865                + up_proj
1866                + down_proj
1867        };
1868        Ok(vec![
1869            per_layer_elems * dtype.size_in_bytes();
1870            cfg.num_hidden_layers
1871        ])
1872    }
1873
1874    fn num_layers(&self, config: &str) -> Result<usize> {
1875        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1876
1877        Ok(cfg.num_hidden_layers)
1878    }
1879    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1880        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1881
1882        let cfg = ModelConfigMetadata {
1883            max_seq_len: cfg.max_position_embeddings,
1884            num_layers: cfg.num_hidden_layers,
1885            hidden_size: cfg.hidden_size,
1886            num_kv_heads: cfg.num_key_value_heads,
1887            num_attn_heads: cfg.num_attention_heads,
1888            sliding_window: None, // None to be more forgiving; some configs do not set it
1889            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1890            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1891        };
1892
1893        Ok(Box::new(cfg))
1894    }
1895}
1896
1897// ======================== Starcoder2 loader
1898
1899/// [`NormalLoader`] for a Starcoder2 model.
1900///
1901/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1902pub struct Starcoder2Loader;
1903
1904impl NormalModelLoader for Starcoder2Loader {
1905    fn load(
1906        &self,
1907        config: &str,
1908        vb: ShardedVarBuilder,
1909        normal_loading_metadata: NormalLoadingMetadata,
1910        attention_mechanism: AttentionImplementation,
1911    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1912        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1913
1914        Ok(Box::new(models::starcoder2::Model::new(
1915            &cfg,
1916            vb,
1917            self.is_gptx(config)?,
1918            normal_loading_metadata,
1919            attention_mechanism,
1920        )?))
1921    }
1922    fn load_xlora(
1923        &self,
1924        config: &str,
1925        vb: ShardedVarBuilder,
1926        lora_config: &[((String, String), LoraConfig)],
1927        xlora_config: Option<XLoraConfig>,
1928        xlora_ordering: Ordering,
1929        normal_loading_metadata: NormalLoadingMetadata,
1930        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1931    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1932        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1933
1934        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
1935            &cfg,
1936            vb,
1937            lora_config,
1938            xlora_config,
1939            xlora_ordering,
1940            self.is_gptx(config)?,
1941            normal_loading_metadata,
1942            preload_adapters,
1943        )?))
1944    }
1945    fn is_gptx(&self, _: &str) -> Result<bool> {
1946        Ok(true)
1947    }
1948    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1949        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1950
1951        Ok(Box::new(cfg))
1952    }
1953}
1954
1955impl IsqModelLoader for Starcoder2Loader {
1956    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1957        Ok(vec![
1958            Regex::new(r"lm_head\.(weight|bias)$")?,
1959            // Attention
1960            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1961            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1962            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1963            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1964            // MLP
1965            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1966            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
1967        ])
1968    }
1969    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1970        self.isq_layer_regexes(config)
1971    }
1972}
1973
1974impl DeviceMappedModelLoader for Starcoder2Loader {
1975    fn mapped_max_act_size_elems(
1976        &self,
1977        config: &str,
1978        params: &AutoDeviceMapParams,
1979        prompt_chunksize: usize,
1980    ) -> Result<usize> {
1981        let AutoDeviceMapParams::Text {
1982            max_seq_len: _,
1983            max_batch_size,
1984        } = params
1985        else {
1986            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1987        };
1988
1989        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1990
1991        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1992    }
1993    fn non_mapped_max_act_size_elems(
1994        &self,
1995        _config: &str,
1996        _params: &AutoDeviceMapParams,
1997    ) -> Result<usize> {
1998        Ok(0)
1999    }
2000
2001    fn non_mapped_size_in_bytes(
2002        &self,
2003        config: &str,
2004        dtype: DType,
2005        weight_pack_factor: usize,
2006    ) -> Result<usize> {
2007        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2008
2009        let elems = {
2010            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2011            let lm_head = if !cfg.tie_word_embeddings {
2012                cfg.hidden_size * cfg.vocab_size
2013            } else {
2014                0
2015            };
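            // Starcoder2 uses LayerNorm with both weight and bias, hence hidden_size twice.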
2016            let norm = cfg.hidden_size + cfg.hidden_size;
2017            embed_tokens + lm_head + norm
2018        };
2019        Ok(elems * dtype.size_in_bytes())
2020    }
2021
2022    fn layer_sizes_in_bytes(
2023        &self,
2024        config: &str,
2025        dtype: DType,
2026        weight_pack_factor: usize,
2027    ) -> Result<Vec<usize>> {
2028        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2029
2030        let per_layer_elems = {
2031            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2032            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2033
2034            let size_in = cfg.hidden_size;
2035            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2036            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2037            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2038            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2039            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2040            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2041
2042            let h_size = cfg.hidden_size;
2043            let i_size = cfg.intermediate_size;
2044            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2045            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2046
2047            input_layernorm
2048                + post_attention_layernorm
2049                + q_proj
2050                + k_proj
2051                + v_proj
2052                + o_proj
2053                + fc1
2054                + fc2
2055        };
2056        Ok(vec![
2057            per_layer_elems * dtype.size_in_bytes();
2058            cfg.num_hidden_layers
2059        ])
2060    }
2061
2062    fn num_layers(&self, config: &str) -> Result<usize> {
2063        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2064
2065        Ok(cfg.num_hidden_layers)
2066    }
2067
2068    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2069        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2070
2071        let cfg = ModelConfigMetadata {
2072            max_seq_len: cfg.max_position_embeddings,
2073            num_layers: cfg.num_hidden_layers,
2074            hidden_size: cfg.hidden_size,
2075            num_kv_heads: cfg.num_key_value_heads,
2076            num_attn_heads: cfg.num_attention_heads,
2077            sliding_window: cfg.sliding_window,
2078            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2079            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2080        };
2081
2082        Ok(Box::new(cfg))
2083    }
2084}
2085
2086// ======================== Phi3.5 MoE loader
2087
2088/// [`NormalLoader`] for a Phi 3.5 MoE model.
2089///
2090/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2091pub struct Phi3_5MoELoader;
2092
2093impl NormalModelLoader for Phi3_5MoELoader {
2094    fn load(
2095        &self,
2096        config: &str,
2097        vb: ShardedVarBuilder,
2098        normal_loading_metadata: NormalLoadingMetadata,
2099        attention_mechanism: AttentionImplementation,
2100    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2101        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2102
2103        Ok(Box::new(models::phi3_5_moe::Model::new(
2104            &cfg,
2105            vb,
2106            self.is_gptx(config)?,
2107            normal_loading_metadata,
2108            attention_mechanism,
2109        )?))
2110    }
2111    fn load_xlora(
2112        &self,
2113        config: &str,
2114        vb: ShardedVarBuilder,
2115        lora_config: &[((String, String), LoraConfig)],
2116        xlora_config: Option<XLoraConfig>,
2117        xlora_ordering: Ordering,
2118        normal_loading_metadata: NormalLoadingMetadata,
2119        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2120    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
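        // Note: X-LoRA loading here reuses the dense Phi3 config and the XLoraPhi3 model.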
2121        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2122
2123        Ok(Box::new(xlora_models::XLoraPhi3::new(
2124            &cfg,
2125            vb,
2126            lora_config,
2127            xlora_config,
2128            xlora_ordering,
2129            self.is_gptx(config)?,
2130            normal_loading_metadata,
2131            preload_adapters,
2132        )?))
2133    }
2134    fn is_gptx(&self, _: &str) -> Result<bool> {
2135        Ok(true)
2136    }
2137    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2138        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2139
2140        Ok(Box::new(cfg))
2141    }
2142}
2143
2144impl IsqModelLoader for Phi3_5MoELoader {
2145    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2146        Ok(vec![
2147            Regex::new(r"lm_head\.(weight|bias)$")?,
2148            // Attention
2149            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2150            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2151            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2152            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2153            // MLP
2154            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2155            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2156            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2157        ])
2158    }
2159    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2160        self.isq_layer_regexes(config)
2161    }
2162
2163    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2164        Ok(vec![
2165            Regex::new(r"lm_head\.(weight|bias)$")?,
2166            // MLP
2167            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2168            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2169            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2170        ])
2171    }
2172    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2173        self.isq_layer_regexes_moqe(config)
2174    }
2175}
2176
2177impl DeviceMappedModelLoader for Phi3_5MoELoader {
2178    fn mapped_max_act_size_elems(
2179        &self,
2180        config: &str,
2181        params: &AutoDeviceMapParams,
2182        prompt_chunksize: usize,
2183    ) -> Result<usize> {
2184        let AutoDeviceMapParams::Text {
2185            max_seq_len: _,
2186            max_batch_size,
2187        } = params
2188        else {
2189            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2190        };
2191
2192        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2193
2194        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2195    }
2196    fn non_mapped_max_act_size_elems(
2197        &self,
2198        _config: &str,
2199        _params: &AutoDeviceMapParams,
2200    ) -> Result<usize> {
2201        Ok(0)
2202    }
2203
2204    fn non_mapped_size_in_bytes(
2205        &self,
2206        config: &str,
2207        dtype: DType,
2208        weight_pack_factor: usize,
2209    ) -> Result<usize> {
2210        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2211
2212        let elems = {
2213            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2214            let lm_head = if !cfg.tie_word_embeddings {
2215                cfg.hidden_size * cfg.vocab_size
2216            } else {
2217                0
2218            };
2219            let norm = cfg.hidden_size;
2220            embed_tokens + lm_head + norm
2221        };
2222        Ok(elems * dtype.size_in_bytes())
2223    }
2224
2225    fn layer_sizes_in_bytes(
2226        &self,
2227        config: &str,
2228        dtype: DType,
2229        weight_pack_factor: usize,
2230    ) -> Result<Vec<usize>> {
2231        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2232
2233        let per_layer_elems = {
2234            let input_layernorm = cfg.hidden_size;
2235            let post_attention_layernorm = cfg.hidden_size;
2236
2237            let size_in = cfg.hidden_size;
2238            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2239            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2240            let q_proj =
2241                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2242            let k_proj =
2243                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2244            let v_proj =
2245                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2246            let o_proj =
2247                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2248
2249            let moe_block = {
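                // Router (gate) weight plus per-expert w1/w2/w3 projections.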
2250                let gate = cfg.hidden_size * cfg.num_local_experts;
2251                // Assume the expert weights are quantized, so divide by the weight pack factor
2252                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2253                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2254                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2255                gate + cfg.num_local_experts * w1
2256                    + cfg.num_local_experts * w2
2257                    + cfg.num_local_experts * w3
2258            };
2259
2260            input_layernorm
2261                + post_attention_layernorm
2262                + q_proj
2263                + k_proj
2264                + v_proj
2265                + o_proj
2266                + moe_block
2267        };
2268        Ok(vec![
2269            per_layer_elems * dtype.size_in_bytes();
2270            cfg.num_hidden_layers
2271        ])
2272    }
2273
2274    fn num_layers(&self, config: &str) -> Result<usize> {
2275        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2276
2277        Ok(cfg.num_hidden_layers)
2278    }
2279
2280    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2281        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2282
2283        let cfg = ModelConfigMetadata {
2284            max_seq_len: cfg.max_position_embeddings,
2285            num_layers: cfg.num_hidden_layers,
2286            hidden_size: cfg.hidden_size,
2287            num_kv_heads: cfg.num_key_value_heads,
2288            num_attn_heads: cfg.num_attention_heads,
2289            sliding_window: cfg.sliding_window,
2290            k_head_dim: cfg.head_dim(),
2291            v_head_dim: cfg.head_dim(),
2292        };
2293
2294        Ok(Box::new(cfg))
2295    }
2296}
2297
2298/// [`NormalLoader`] for a DeepSeekV2 model.
2299///
2300/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
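///
/// A minimal sketch (illustrative only, not from the crate docs) of printing the parsed
/// config; `CONFIG_JSON` is a placeholder for the model's `config.json` contents:
///
/// ```ignore
/// // CONFIG_JSON: stand-in for the model's config.json string.
/// let loader = DeepSeekV2Loader;
/// let repr = loader.get_config_repr(CONFIG_JSON)?;
/// println!("{repr:?}");
/// ```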
2301pub struct DeepSeekV2Loader;
2302
2303impl NormalModelLoader for DeepSeekV2Loader {
2304    fn load(
2305        &self,
2306        config: &str,
2307        vb: ShardedVarBuilder,
2308        normal_loading_metadata: NormalLoadingMetadata,
2309        attention_mechanism: AttentionImplementation,
2310    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2311        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2312
2313        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2314            &cfg,
2315            vb,
2316            self.is_gptx(config)?,
2317            normal_loading_metadata,
2318            attention_mechanism,
2319        )?))
2320    }
2321    fn load_xlora(
2322        &self,
2323        _config: &str,
2324        _vb: ShardedVarBuilder,
2325        _lora_config: &[((String, String), LoraConfig)],
2326        _xlora_config: Option<XLoraConfig>,
2327        _xlora_ordering: Ordering,
2328        _normal_loading_metadata: NormalLoadingMetadata,
2329        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2330    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2331        todo!()
2332    }
2333    fn is_gptx(&self, _: &str) -> Result<bool> {
2334        Ok(true)
2335    }
2336    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2337        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2338        Ok(Box::new(cfg))
2339    }
2340}
2341
2342impl IsqModelLoader for DeepSeekV2Loader {
2343    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2344        let mut data = vec![
2345            Regex::new(r"lm_head\.(weight|bias)$")?,
2346            // Attention
2347            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2348            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2349            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2350        ];
2351        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2352        if cfg.q_lora_rank.is_some() {
2353            data.extend(vec![
2354                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2355                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2356            ]);
2357        } else {
2358            data.push(Regex::new(
2359                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2360            )?);
2361        }
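        // Per-layer patterns: expert (and optional shared-expert) projections for MoE
        // layers, plain gate/up/down projections for dense layers.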
2362        for layer_idx in 0..cfg.num_hidden_layers {
2363            if cfg.n_routed_experts.is_some()
2364                && layer_idx >= cfg.first_k_dense_replace
2365                && layer_idx % cfg.moe_layer_freq == 0
2366            {
2367                for i in 0..cfg.n_routed_experts.unwrap() {
2368                    data.extend(vec![
2369                        Regex::new(&format!(
2370                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2371                        ))?,
2372                        Regex::new(&format!(
2373                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2374                        ))?,
2375                        Regex::new(&format!(
2376                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2377                        ))?,
2378                    ]);
2379                }
2380                if cfg.n_shared_experts.is_some() {
2381                    data.extend(vec![
2382                        Regex::new(&format!(
2383                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2384                        ))?,
2385                        Regex::new(&format!(
2386                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2387                        ))?,
2388                        Regex::new(&format!(
2389                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2390                        ))?,
2391                    ]);
2392                }
2393            } else {
2394                data.extend(vec![
2395                    Regex::new(&format!(
2396                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2397                    ))?,
2398                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2399                    Regex::new(&format!(
2400                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2401                    ))?,
2402                ]);
2403            };
2404        }
2405        Ok(data)
2406    }
2407    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2408        self.isq_layer_regexes(config)
2409    }
2410
2411    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2412        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2413        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2414        for layer_idx in 0..cfg.num_hidden_layers {
2415            if cfg.n_routed_experts.is_some()
2416                && layer_idx >= cfg.first_k_dense_replace
2417                && layer_idx % cfg.moe_layer_freq == 0
2418            {
2419                for i in 0..cfg.n_routed_experts.unwrap() {
2420                    data.extend(vec![
2421                        Regex::new(&format!(
2422                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2423                        ))?,
2424                        Regex::new(&format!(
2425                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2426                        ))?,
2427                        Regex::new(&format!(
2428                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2429                        ))?,
2430                    ]);
2431                }
2432                if cfg.n_shared_experts.is_some() {
2433                    data.extend(vec![
2434                        Regex::new(&format!(
2435                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2436                        ))?,
2437                        Regex::new(&format!(
2438                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2439                        ))?,
2440                        Regex::new(&format!(
2441                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2442                        ))?,
2443                    ]);
2444                }
2445            } else {
2446                data.extend(vec![
2447                    Regex::new(&format!(
2448                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2449                    ))?,
2450                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2451                    Regex::new(&format!(
2452                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2453                    ))?,
2454                ]);
2455            };
2456        }
2457        Ok(data)
2458    }
2459    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2460        self.isq_layer_regexes_moqe(config)
2461    }
2462}
2463
2464impl DeviceMappedModelLoader for DeepSeekV2Loader {
2465    fn mapped_max_act_size_elems(
2466        &self,
2467        config: &str,
2468        params: &AutoDeviceMapParams,
2469        prompt_chunksize: usize,
2470    ) -> Result<usize> {
2471        let AutoDeviceMapParams::Text {
2472            max_seq_len: _,
2473            max_batch_size,
2474        } = params
2475        else {
2476            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2477        };
2478
2479        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2480
2481        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2482    }
2483    fn non_mapped_max_act_size_elems(
2484        &self,
2485        _config: &str,
2486        _params: &AutoDeviceMapParams,
2487    ) -> Result<usize> {
2488        Ok(0)
2489    }
2490
2491    fn non_mapped_size_in_bytes(
2492        &self,
2493        config: &str,
2494        dtype: DType,
2495        weight_pack_factor: usize,
2496    ) -> Result<usize> {
2497        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2498        let elems = {
2499            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2500            let lm_head = if !cfg.tie_word_embeddings {
2501                cfg.hidden_size * cfg.vocab_size
2502            } else {
2503                0
2504            };
2505            let norm = cfg.hidden_size;
2506            embed_tokens + lm_head + norm
2507        };
2508        Ok(elems * dtype.size_in_bytes())
2509    }
2510
2511    fn layer_sizes_in_bytes(
2512        &self,
2513        config: &str,
2514        dtype: DType,
2515        weight_pack_factor: usize,
2516    ) -> Result<Vec<usize>> {
2517        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2518        let mut per_layer_elems = Vec::new();
2519
2520        for layer_idx in 0..cfg.num_hidden_layers {
2521            let input_layernorm = cfg.hidden_size;
2522            let post_attention_layernorm = cfg.hidden_size;
2523
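            // With q_lora_rank set, the q projection is factored into a down-projection,
            // a layernorm, and an up-projection; otherwise it is a single dense matrix.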
2524            let q_proj = match cfg.q_lora_rank {
2525                Some(lora_rank) => {
2526                    let a = cfg.hidden_size * lora_rank;
2527                    let norm = lora_rank;
2528                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2529                    a + norm + b
2530                }
2531                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2532            };
2533            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2534                / weight_pack_factor
2535                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2536            let kv_a_layernorm = cfg.kv_lora_rank;
2537            let kv_b_proj = cfg.kv_lora_rank
2538                * cfg.num_attention_heads
2539                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2540                / weight_pack_factor;
2541            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2542                / weight_pack_factor
2543                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2544
2545            let moe_block = {
2546                let mut sum = 0;
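                // Layers at or beyond `first_k_dense_replace` whose index is a multiple of
                // `moe_layer_freq` use routed experts; all others use a dense MLP (the `else` arm).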
2547                if cfg.n_routed_experts.is_some()
2548                    && layer_idx >= cfg.first_k_dense_replace
2549                    && layer_idx % cfg.moe_layer_freq == 0
2550                {
2551                    let h_size = cfg.hidden_size;
2552                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2553                        * cfg.n_routed_experts.unwrap();
2554                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2555                        * cfg.n_routed_experts.unwrap();
2556                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2557                        * cfg.n_routed_experts.unwrap();
2558                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2559                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2560                            / weight_pack_factor;
2561                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2562                            / weight_pack_factor;
2563                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2564                            / weight_pack_factor;
2565                        gate_proj + up_proj + down_proj
2566                    } else {
2567                        0
2568                    };
2569                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2570                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2571                } else {
2572                    let h_size = cfg.hidden_size;
2573                    let i_size = cfg.intermediate_size;
2574                    let gate_proj = h_size * i_size / weight_pack_factor;
2575                    let up_proj = h_size * i_size / weight_pack_factor;
2576                    let down_proj = i_size * h_size / weight_pack_factor;
2577                    sum += gate_proj + up_proj + down_proj;
2578                }
2579                sum
2580            };
2581
2582            per_layer_elems.push(
2583                input_layernorm
2584                    + post_attention_layernorm
2585                    + q_proj
2586                    + kv_a_layernorm
2587                    + kv_a_proj_with_mqa
2588                    + kv_b_proj
2589                    + o_proj
2590                    + moe_block,
2591            );
2592        }
2593
2594        Ok(per_layer_elems
2595            .into_iter()
2596            .map(|x| x * dtype.size_in_bytes())
2597            .collect())
2598    }
2599
2600    fn num_layers(&self, config: &str) -> Result<usize> {
2601        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2602        Ok(cfg.num_hidden_layers)
2603    }
2604
2605    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2606        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2607
2608        let cfg = ModelConfigMetadata {
2609            max_seq_len: cfg.max_position_embeddings,
2610            num_layers: cfg.num_hidden_layers,
2611            hidden_size: cfg.hidden_size,
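            // MLA reconstructs per-head K/V from the compressed latent, so the KV head
            // count is reported equal to the attention head count here.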
2612            num_kv_heads: cfg.num_attention_heads,
2613            num_attn_heads: cfg.num_attention_heads,
2614            sliding_window: None,
2615            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2616            v_head_dim: cfg.v_head_dim,
2617        };
2618
2619        Ok(Box::new(cfg))
2620    }
2621}
2622
2623/// [`NormalLoader`] for a DeepSeekV3 model.
2624///
2625/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2626pub struct DeepSeekV3Loader;
2627
2628impl NormalModelLoader for DeepSeekV3Loader {
2629    fn load(
2630        &self,
2631        config: &str,
2632        vb: ShardedVarBuilder,
2633        normal_loading_metadata: NormalLoadingMetadata,
2634        attention_mechanism: AttentionImplementation,
2635    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2636        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2637        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2638            &cfg,
2639            vb,
2640            self.is_gptx(config)?,
2641            normal_loading_metadata,
2642            attention_mechanism,
2643        )?))
2644    }
2645    fn load_xlora(
2646        &self,
2647        _config: &str,
2648        _vb: ShardedVarBuilder,
2649        _lora_config: &[((String, String), LoraConfig)],
2650        _xlora_config: Option<XLoraConfig>,
2651        _xlora_ordering: Ordering,
2652        _normal_loading_metadata: NormalLoadingMetadata,
2653        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2654    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2655        todo!()
2656    }
2657    fn is_gptx(&self, _: &str) -> Result<bool> {
2658        Ok(true)
2659    }
2660    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2661        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2662        Ok(Box::new(cfg))
2663    }
2664}
2665
2666impl IsqModelLoader for DeepSeekV3Loader {
2667    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2668        let mut data = vec![
2669            Regex::new(r"lm_head\.(weight|bias)$")?,
2670            // Attention
2671            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2672            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2673            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2674        ];
2675        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2676        if cfg.q_lora_rank.is_some() {
2677            data.extend(vec![
2678                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2679                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2680            ]);
2681        } else {
2682            data.push(Regex::new(
2683                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2684            )?);
2685        }
2686        for layer_idx in 0..cfg.num_hidden_layers {
2687            if cfg.n_routed_experts.is_some()
2688                && layer_idx >= cfg.first_k_dense_replace
2689                && layer_idx % cfg.moe_layer_freq == 0
2690            {
2691                for i in 0..cfg.n_routed_experts.unwrap() {
2692                    data.extend(vec![
2693                        Regex::new(&format!(
2694                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2695                        ))?,
2696                        Regex::new(&format!(
2697                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2698                        ))?,
2699                        Regex::new(&format!(
2700                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2701                        ))?,
2702                    ]);
2703                }
2704                if cfg.n_shared_experts.is_some() {
2705                    data.extend(vec![
2706                        Regex::new(&format!(
2707                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2708                        ))?,
2709                        Regex::new(&format!(
2710                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2711                        ))?,
2712                        Regex::new(&format!(
2713                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2714                        ))?,
2715                    ]);
2716                }
2717            } else {
2718                data.extend(vec![
2719                    Regex::new(&format!(
2720                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2721                    ))?,
2722                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2723                    Regex::new(&format!(
2724                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2725                    ))?,
2726                ]);
2727            };
2728        }
2729        Ok(data)
2730    }
2731    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2732        self.isq_layer_regexes(config)
2733    }
2734
2735    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2736        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2737        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2738        for layer_idx in 0..cfg.num_hidden_layers {
2739            if cfg.n_routed_experts.is_some()
2740                && layer_idx >= cfg.first_k_dense_replace
2741                && layer_idx % cfg.moe_layer_freq == 0
2742            {
2743                for i in 0..cfg.n_routed_experts.unwrap() {
2744                    data.extend(vec![
2745                        Regex::new(&format!(
2746                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2747                        ))?,
2748                        Regex::new(&format!(
2749                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2750                        ))?,
2751                        Regex::new(&format!(
2752                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2753                        ))?,
2754                    ]);
2755                }
2756                if cfg.n_shared_experts.is_some() {
2757                    data.extend(vec![
2758                        Regex::new(&format!(
2759                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2760                        ))?,
2761                        Regex::new(&format!(
2762                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2763                        ))?,
2764                        Regex::new(&format!(
2765                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2766                        ))?,
2767                    ]);
2768                }
2769            } else {
2770                data.extend(vec![
2771                    Regex::new(&format!(
2772                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2773                    ))?,
2774                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2775                    Regex::new(&format!(
2776                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2777                    ))?,
2778                ]);
2779            };
2780        }
2781        Ok(data)
2782    }
2783    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2784        self.isq_layer_regexes_moqe(config)
2785    }
2786}
2787
2788impl DeviceMappedModelLoader for DeepSeekV3Loader {
2789    fn mapped_max_act_size_elems(
2790        &self,
2791        config: &str,
2792        params: &AutoDeviceMapParams,
2793        prompt_chunksize: usize,
2794    ) -> Result<usize> {
2795        let AutoDeviceMapParams::Text {
2796            max_seq_len: _,
2797            max_batch_size,
2798        } = params
2799        else {
2800            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2801        };
2802
2803        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2804
2805        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2806    }
2807    fn non_mapped_max_act_size_elems(
2808        &self,
2809        _config: &str,
2810        _params: &AutoDeviceMapParams,
2811    ) -> Result<usize> {
2812        Ok(0)
2813    }
2814
2815    fn non_mapped_size_in_bytes(
2816        &self,
2817        config: &str,
2818        dtype: DType,
2819        weight_pack_factor: usize,
2820    ) -> Result<usize> {
2821        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2822        let elems = {
2823            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2824            let lm_head = if !cfg.tie_word_embeddings {
2825                cfg.hidden_size * cfg.vocab_size
2826            } else {
2827                0
2828            };
2829            let norm = cfg.hidden_size;
2830            embed_tokens + lm_head + norm
2831        };
2832        Ok(elems * dtype.size_in_bytes())
2833    }
2834
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
        let mut per_layer_elems = Vec::new();

        for layer_idx in 0..cfg.num_hidden_layers {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let q_proj = match cfg.q_lora_rank {
                Some(lora_rank) => {
                    let a = cfg.hidden_size * lora_rank;
                    let norm = lora_rank;
                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
                    a + norm + b
                }
                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
            };
            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
                / weight_pack_factor
                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
            let kv_a_layernorm = cfg.kv_lora_rank;
            let kv_b_proj = cfg.kv_lora_rank
                * cfg.num_attention_heads
                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
                / weight_pack_factor;
            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
                / weight_pack_factor
                + bias_if!(cfg.attention_bias, cfg.hidden_size);

            let moe_block = {
                let mut sum = 0;
                if cfg.n_routed_experts.is_some()
                    && layer_idx >= cfg.first_k_dense_replace
                    && layer_idx % cfg.moe_layer_freq == 0
                {
                    let h_size = cfg.hidden_size;
                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
                        * cfg.n_routed_experts.unwrap();
                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
                        * cfg.n_routed_experts.unwrap();
                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
                        * cfg.n_routed_experts.unwrap();
                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
                            / weight_pack_factor;
                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
                            / weight_pack_factor;
                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
                            / weight_pack_factor;
                        gate_proj + up_proj + down_proj
                    } else {
                        0
                    };
                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
                } else {
                    let h_size = cfg.hidden_size;
                    let i_size = cfg.intermediate_size;
                    let gate_proj = h_size * i_size / weight_pack_factor;
                    let up_proj = h_size * i_size / weight_pack_factor;
                    let down_proj = i_size * h_size / weight_pack_factor;
                    sum += gate_proj + up_proj + down_proj;
                }
                sum
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + kv_a_layernorm
                    + kv_a_proj_with_mqa
                    + kv_b_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_attention_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
            v_head_dim: cfg.v_head_dim,
        };

        Ok(Box::new(cfg))
    }
}

/// [`NormalLoader`] for a Qwen 3 model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Qwen3Loader;

impl NormalModelLoader for Qwen3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::qwen3::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

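// Regexes selecting the weights eligible for in-situ quantization (ISQ) in Qwen 3:
// the lm_head plus each layer's attention and MLP projections.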
impl IsqModelLoader for Qwen3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

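// Weight-size and activation estimates used by the automatic device mapper to decide
// how many Qwen 3 decoder layers fit on each device.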
impl DeviceMappedModelLoader for Qwen3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: models::qwen3::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

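    // Every Qwen 3 decoder layer has identical shape, so one per-layer element count is
    // computed and repeated `num_hidden_layers` times. `q_norm`/`k_norm` are the
    // head_dim-sized norms applied to the query and key heads.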
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            let q_norm = cfg.head_dim();
            let k_norm = cfg.head_dim();

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
                + q_norm
                + k_norm
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: models::qwen3::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            // Use the configured head_dim; Qwen 3 allows it to differ from
            // hidden_size / num_attention_heads.
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

/// [`NormalLoader`] for a Qwen 3 MoE model.
///
/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Qwen3MoELoader;

impl NormalModelLoader for Qwen3MoELoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::qwen3_moe::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

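// ISQ targets for Qwen 3 MoE: in addition to the lm_head and the dense attention/MLP
// projections, every per-expert gate/up/down projection is eligible.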
impl IsqModelLoader for Qwen3MoELoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
            // MLP MoE
            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
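    // The MoQE (mixture of quantized experts) variant defers to the trait-provided
    // `isq_layer_regexes_moqe` set.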
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for Qwen3MoELoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

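    // Unlike the dense Qwen 3 loader above, layer sizes can differ per layer: a layer is
    // sized as a routed-expert MoE block only when it is not listed in `mlp_only_layers`,
    // `num_experts` is nonzero, and its 1-based index is a multiple of
    // `decoder_sparse_step`; otherwise it is a single dense MLP of `intermediate_size`.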
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;

        let mut layer_sizes_in_bytes = Vec::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
            {
                let expert_size = {
                    let h_size = cfg.hidden_size;
                    let i_size = cfg.moe_intermediate_size;
                    let gate_proj = h_size * i_size / weight_pack_factor;
                    let up_proj = h_size * i_size / weight_pack_factor;
                    let down_proj = i_size * h_size / weight_pack_factor;
                    gate_proj + up_proj + down_proj
                };
                expert_size * cfg.num_experts
            } else {
                let h_size = cfg.hidden_size;
                let i_size = cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;
                gate_proj + up_proj + down_proj
            };

            let q_norm = cfg.head_dim();
            let k_norm = cfg.head_dim();

            let size_elems = input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + mlp_size
                + q_norm
                + k_norm;

            let size_in_bytes = size_elems * dtype.size_in_bytes();
            layer_sizes_in_bytes.push(size_in_bytes);
        }

        Ok(layer_sizes_in_bytes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            // Use the configured head_dim; it may differ from
            // hidden_size / num_attention_heads.
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}