mistralrs_core/pipeline/loaders/
normal_loaders.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    str::FromStr,
5    sync::Arc,
6};
7
8use crate::{
9    amoe::AnyMoeBaseModelMixin,
10    device_map::DeviceMapper,
11    layers::{Activation, Llama3RopeConfig, PhiRopeScalingConfig},
12    lora::{LoraConfig, Ordering},
13    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
14    pipeline::{
15        isq::IsqModelLoader,
16        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
17        EitherCache, IsqModel,
18    },
19    serde_default_fn,
20    utils::{log::once_log_info, varbuilder_utils::DeviceForLoadTensor},
21    xlora_models::NonGranularState,
22};
23use anyhow::Result;
24use candle_core::{DType, Device, Tensor};
25
26use indicatif::MultiProgress;
27use mistralrs_quant::{QuantizedConfig, ShardedVarBuilder};
28#[cfg(feature = "pyo3_macros")]
29use pyo3::pyclass;
30
31use regex::Regex;
32use serde::Deserialize;
33
34use crate::{
35    models,
36    xlora_models::{self, XLoraConfig},
37};
38
39use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
40
41pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
42    #[allow(clippy::too_many_arguments)]
43    fn forward(
44        &self,
45        input_ids: &Tensor,
46        seqlen_offsets: &[usize],
47        context_lens: Vec<(usize, usize)>,
48        position_ids: Vec<usize>,
49        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
50        flash_params: &FlashParams,
51    ) -> candle_core::Result<Tensor>;
52    #[allow(clippy::too_many_arguments)]
53    fn xlora_forward(
54        &self,
55        input_ids: &Tensor,
56        input_ids_full: &Tensor,
57        seqlen_offsets: &[usize],
58        seqlen_offsets_full: &[usize],
59        no_kv_cache: bool,
60        non_granular_state: &Option<NonGranularState>,
61        context_lens: Vec<(usize, usize)>,
62        position_ids: Vec<usize>,
63        flash_params: &FlashParams,
64        flash_params_full: &FlashParams,
65    ) -> candle_core::Result<Tensor>;
66    fn is_xlora(&self) -> bool;
67    fn device(&self) -> &Device;
68    fn cache(&self) -> &EitherCache;
69    fn cache_mut(&mut self) -> &mut EitherCache;
70    fn max_seq_len(&self) -> usize;
71    fn config(&self) -> &ModelConfigMetadata;
72}
73
74/// Metadata for loading a model with ISQ or device mapping.
75pub struct NormalLoadingMetadata {
76    // Device mapping metadata which can be used to construct a concrete device mapper
77    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
78    // Flag to check if loading in ISQ
79    pub loading_isq: bool,
80    // Device mapping target device (the one that is not the cpu)
81    pub real_device: Device,
82    // MultiProgress support for parallelized loading
83    pub multi_progress: Arc<MultiProgress>,
84}
85
86pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
87    fn load(
88        &self,
89        config: &str,
90        use_flash_attn: bool,
91        vb: ShardedVarBuilder,
92        normal_loading_metadata: NormalLoadingMetadata,
93        attention_mechanism: AttentionImplementation,
94    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
95    #[allow(clippy::too_many_arguments)]
96    fn load_xlora(
97        &self,
98        config: &str,
99        use_flash_attn: bool,
100        vb: ShardedVarBuilder,
101        lora_config: &[((String, String), LoraConfig)],
102        xlora_config: Option<XLoraConfig>,
103        xlora_ordering: Ordering,
104        normal_loading_metadata: NormalLoadingMetadata,
105        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
106    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
107    fn is_gptx(&self, config: &str) -> Result<bool>;
108    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>>;
109    fn get_device_for_tensor(
110        &self,
111        config: &str,
112        _mapper: &dyn DeviceMapper,
113        loading_isq: bool,
114    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
115        if loading_isq {
116            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
117        } else {
118            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
119            let num_layers = self.model_config(config)?.num_layers();
120            let closure = move |name: String| {
121                if let Some(captures) = re.captures(&name) {
122                    captures
123                        .get(1)
124                        .and_then(|m| m.as_str().parse::<usize>().ok())
125                        .map(|l| l.min(num_layers))
126                        .map(DeviceForLoadTensor::Idx)
127                        .unwrap_or(DeviceForLoadTensor::Base)
128                } else {
129                    DeviceForLoadTensor::Base
130                }
131            };
132
133            Ok(Arc::new(closure))
134        }
135    }
136}
137
138#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
139#[derive(Clone, Debug, Deserialize, PartialEq)]
140/// The architecture to load the normal model as.
141pub enum NormalLoaderType {
142    #[serde(rename = "mistral")]
143    Mistral,
144    #[serde(rename = "gemma")]
145    Gemma,
146    #[serde(rename = "mixtral")]
147    Mixtral,
148    #[serde(rename = "llama")]
149    Llama,
150    #[serde(rename = "phi2")]
151    Phi2,
152    #[serde(rename = "phi3")]
153    Phi3,
154    #[serde(rename = "qwen2")]
155    Qwen2,
156    #[serde(rename = "gemma2")]
157    Gemma2,
158    #[serde(rename = "starcoder2")]
159    Starcoder2,
160    #[serde(rename = "phi3.5moe")]
161    Phi3_5MoE,
162    #[serde(rename = "deepseekv2")]
163    DeepSeekV2,
164    #[serde(rename = "deepseekv3")]
165    DeepSeekV3,
166}
167
168// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
169
170impl NormalLoaderType {
171    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
172        match name {
173            "MistralForCausalLM" => Ok(Self::Mistral),
174            "MixtralForCausalLM" => Ok(Self::Mixtral),
175            "GemmaForCausalLM" => Ok(Self::Gemma),
176            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
177            "PhiForCausalLM" => Ok(Self::Phi2),
178            "Phi3ForCausalLM" => Ok(Self::Phi3),
179            "LlamaForCausalLM" => Ok(Self::Llama),
180            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
181            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
182            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
183            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
184            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
185            other => anyhow::bail!(
186                "Unsupported Huggging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
187            ),
188        }
189    }
190}
191
192impl FromStr for NormalLoaderType {
193    type Err = String;
194    fn from_str(s: &str) -> Result<Self, Self::Err> {
195        match s {
196            "mistral" => Ok(Self::Mistral),
197            "gemma" => Ok(Self::Gemma),
198            "mixtral" => Ok(Self::Mixtral),
199            "llama" => Ok(Self::Llama),
200            "phi2" => Ok(Self::Phi2),
201            "phi3" => Ok(Self::Phi3),
202            "qwen2" => Ok(Self::Qwen2),
203            "gemma2" => Ok(Self::Gemma2),
204            "starcoder2" => Ok(Self::Starcoder2),
205            "phi3.5moe" => Ok(Self::Phi3_5MoE),
206            "deepseekv2" => Ok(Self::DeepSeekV2),
207            "deepseekv3" => Ok(Self::DeepSeekV3),
208            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`.")),
209        }
210    }
211}
212
213impl Display for NormalLoaderType {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        match self {
216            Self::Gemma => write!(f, "gemma"),
217            Self::Gemma2 => write!(f, "gemma2"),
218            Self::Llama => write!(f, "llama"),
219            Self::Mistral => write!(f, "mistral"),
220            Self::Mixtral => write!(f, "mixtral"),
221            Self::Phi2 => write!(f, "phi2"),
222            Self::Phi3 => write!(f, "phi3"),
223            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
224            Self::Qwen2 => write!(f, "qwen2"),
225            Self::Starcoder2 => write!(f, "starcoder2"),
226            Self::DeepSeekV2 => write!(f, "deepseekv2"),
227            Self::DeepSeekV3 => write!(f, "deepseekv3"),
228        }
229    }
230}
231
232macro_rules! bias_if {
233    ($cond:expr, $size:expr) => {
234        if $cond {
235            $size
236        } else {
237            0
238        }
239    };
240}
241
242/// Load a model based on the Huggging Face Transformers -CausalLM model class
243pub struct AutoLoader;
244
245#[derive(Deserialize)]
246struct AutoLoaderConfig {
247    architectures: Vec<String>,
248}
249
250impl AutoLoader {
251    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
252        let auto_cfg: AutoLoaderConfig = serde_json::from_str(config)?;
253        if auto_cfg.architectures.len() != 1 {
254            anyhow::bail!("Expected to have one name for `architectures` config field.")
255        }
256
257        let name = &auto_cfg.architectures[0];
258
259        let tp = NormalLoaderType::from_causal_lm_name(name)?;
260
261        once_log_info(format!("Automatic loader type determined to be `{tp}`"));
262
263        match tp {
264            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
265            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
266            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
267            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
268            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
269            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
270            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
271            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
272            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
273            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
274            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
275            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
276        }
277    }
278}
279
280impl NormalModelLoader for AutoLoader {
281    fn load(
282        &self,
283        config: &str,
284        use_flash_attn: bool,
285        vb: ShardedVarBuilder,
286        normal_loading_metadata: NormalLoadingMetadata,
287        attention_mechanism: AttentionImplementation,
288    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
289        Self::get_loader(config)?.load(
290            config,
291            use_flash_attn,
292            vb,
293            normal_loading_metadata,
294            attention_mechanism,
295        )
296    }
297    fn load_xlora(
298        &self,
299        config: &str,
300        use_flash_attn: bool,
301        vb: ShardedVarBuilder,
302        lora_config: &[((String, String), LoraConfig)],
303        xlora_config: Option<XLoraConfig>,
304        xlora_ordering: Ordering,
305        normal_loading_metadata: NormalLoadingMetadata,
306        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
307    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
308        Self::get_loader(config)?.load_xlora(
309            config,
310            use_flash_attn,
311            vb,
312            lora_config,
313            xlora_config,
314            xlora_ordering,
315            normal_loading_metadata,
316            preload_adapters,
317        )
318    }
319    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
320        Self::get_loader(config)?.get_config_repr(config, use_flash_attn)
321    }
322    fn is_gptx(&self, config: &str) -> Result<bool> {
323        Self::get_loader(config)?.is_gptx(config)
324    }
325}
326
327impl IsqModelLoader for AutoLoader {
328    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
329        Self::get_loader(config)?.isq_layer_regexes(config)
330    }
331}
332
333impl DeviceMappedModelLoader for AutoLoader {
334    fn non_mapped_size_in_bytes(
335        &self,
336        config: &str,
337        dtype: DType,
338        weight_pack_factor: usize,
339    ) -> Result<usize> {
340        Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
341    }
342    fn num_layers(&self, config: &str) -> Result<usize> {
343        Self::get_loader(config)?.num_layers(config)
344    }
345    fn layer_sizes_in_bytes(
346        &self,
347        config: &str,
348        dtype: DType,
349        weight_pack_factor: usize,
350    ) -> Result<Vec<usize>> {
351        Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
352    }
353    fn mapped_max_act_size_elems(
354        &self,
355        config: &str,
356        params: &super::AutoDeviceMapParams,
357        prompt_chunksize: usize,
358    ) -> Result<usize> {
359        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
360    }
361    fn non_mapped_max_act_size_elems(
362        &self,
363        _config: &str,
364        _params: &AutoDeviceMapParams,
365    ) -> Result<usize> {
366        Ok(0)
367    }
368    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
369        Self::get_loader(config)?.model_config(config)
370    }
371}
372
373serde_default_fn!(bool, word_emb_default, false);
374
375// ======================== Mistral loader
376
377#[derive(Deserialize, Debug)]
378struct MistralBasicConfig {
379    vocab_size: usize,
380    hidden_size: usize,
381    intermediate_size: usize,
382    num_hidden_layers: usize,
383    num_attention_heads: usize,
384    num_key_value_heads: usize,
385    hidden_act: Activation,
386    max_position_embeddings: usize,
387    rms_norm_eps: f64,
388    rope_theta: f64,
389    sliding_window: Option<usize>,
390    head_dim: Option<usize>,
391    quantization_config: Option<QuantizedConfig>,
392    #[serde(default = "word_emb_default")]
393    tie_word_embeddings: bool,
394}
395
396impl MistralBasicConfig {
397    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::mistral::Config> {
398        let basic_config: Self = serde_json::from_str(slice)?;
399        Ok(models::mistral::Config {
400            vocab_size: basic_config.vocab_size,
401            hidden_size: basic_config.hidden_size,
402            intermediate_size: basic_config.intermediate_size,
403            num_hidden_layers: basic_config.num_hidden_layers,
404            num_attention_heads: basic_config.num_attention_heads,
405            num_key_value_heads: basic_config.num_key_value_heads,
406            hidden_act: basic_config.hidden_act,
407            max_position_embeddings: basic_config.max_position_embeddings,
408            rms_norm_eps: basic_config.rms_norm_eps,
409            rope_theta: basic_config.rope_theta,
410            sliding_window: basic_config.sliding_window,
411            use_flash_attn,
412            head_dim: basic_config.head_dim,
413            quantization_config: basic_config.quantization_config,
414            tie_word_embeddings: basic_config.tie_word_embeddings,
415        })
416    }
417}
418
419pub struct MistralLoader;
420
421impl NormalModelLoader for MistralLoader {
422    fn load(
423        &self,
424        config: &str,
425        use_flash_attn: bool,
426        vb: ShardedVarBuilder,
427        normal_loading_metadata: NormalLoadingMetadata,
428        attention_mechanism: AttentionImplementation,
429    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
430        Ok(Box::new(models::mistral::Model::new(
431            &MistralBasicConfig::deserialize(config, use_flash_attn)?,
432            vb,
433            self.is_gptx(config)?,
434            normal_loading_metadata,
435            attention_mechanism,
436        )?))
437    }
438    fn load_xlora(
439        &self,
440        config: &str,
441        use_flash_attn: bool,
442        vb: ShardedVarBuilder,
443        lora_config: &[((String, String), LoraConfig)],
444        xlora_config: Option<XLoraConfig>,
445        xlora_ordering: Ordering,
446        normal_loading_metadata: NormalLoadingMetadata,
447        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
448    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
449        Ok(Box::new(xlora_models::XLoraMistral::new(
450            &MistralBasicConfig::deserialize(config, use_flash_attn)?,
451            vb,
452            lora_config,
453            xlora_config,
454            xlora_ordering,
455            self.is_gptx(config)?,
456            normal_loading_metadata,
457            preload_adapters,
458        )?))
459    }
460    fn is_gptx(&self, _: &str) -> Result<bool> {
461        Ok(true)
462    }
463    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
464        Ok(Box::new(MistralBasicConfig::deserialize(
465            config,
466            use_flash_attn,
467        )?))
468    }
469}
470
471impl IsqModelLoader for MistralLoader {
472    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
473        Ok(vec![
474            Regex::new(r"lm_head\.(weight|bias)$")?,
475            // Attention
476            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
477            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
478            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
479            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
480            // MLP
481            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
482            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
483            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
484        ])
485    }
486}
487
488impl DeviceMappedModelLoader for MistralLoader {
489    fn mapped_max_act_size_elems(
490        &self,
491        config: &str,
492        params: &AutoDeviceMapParams,
493        prompt_chunksize: usize,
494    ) -> Result<usize> {
495        let AutoDeviceMapParams::Text {
496            max_seq_len: _,
497            max_batch_size,
498        } = params
499        else {
500            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
501        };
502
503        let cfg = MistralBasicConfig::deserialize(config, false)?;
504
505        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
506    }
507    fn non_mapped_max_act_size_elems(
508        &self,
509        _config: &str,
510        _params: &AutoDeviceMapParams,
511    ) -> Result<usize> {
512        Ok(0)
513    }
514
515    fn non_mapped_size_in_bytes(
516        &self,
517        config: &str,
518        dtype: DType,
519        weight_pack_factor: usize,
520    ) -> Result<usize> {
521        let cfg = MistralBasicConfig::deserialize(config, false)?;
522        let elems = {
523            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
524            let lm_head = if !cfg.tie_word_embeddings {
525                cfg.hidden_size * cfg.vocab_size
526            } else {
527                0
528            };
529            let norm = cfg.hidden_size;
530            embed_tokens + lm_head + norm
531        };
532        Ok(elems * dtype.size_in_bytes())
533    }
534
535    fn layer_sizes_in_bytes(
536        &self,
537        config: &str,
538        dtype: DType,
539        weight_pack_factor: usize,
540    ) -> Result<Vec<usize>> {
541        let cfg = MistralBasicConfig::deserialize(config, false)?;
542        let per_layer_elems = {
543            let input_layernorm = cfg.hidden_size;
544            let post_attention_layernorm = cfg.hidden_size;
545
546            let size_in = cfg.hidden_size;
547            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
548            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
549            let q_proj = size_in * size_q / weight_pack_factor;
550            let k_proj = size_in * size_kv / weight_pack_factor;
551            let v_proj = size_in * size_kv / weight_pack_factor;
552            let o_proj = size_q * size_in / weight_pack_factor;
553
554            let h_size = cfg.hidden_size;
555            let i_size = cfg.intermediate_size;
556            let gate_proj = h_size * i_size / weight_pack_factor;
557            let up_proj = h_size * i_size / weight_pack_factor;
558            let down_proj = i_size * h_size / weight_pack_factor;
559
560            input_layernorm
561                + post_attention_layernorm
562                + q_proj
563                + k_proj
564                + v_proj
565                + o_proj
566                + gate_proj
567                + up_proj
568                + down_proj
569        };
570        Ok(vec![
571            per_layer_elems * dtype.size_in_bytes();
572            cfg.num_hidden_layers
573        ])
574    }
575
576    fn num_layers(&self, config: &str) -> Result<usize> {
577        let cfg = MistralBasicConfig::deserialize(config, false)?;
578        Ok(cfg.num_hidden_layers)
579    }
580
581    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
582        let cfg = MistralBasicConfig::deserialize(config, false)?;
583
584        let cfg = ModelConfigMetadata {
585            max_seq_len: cfg.max_position_embeddings,
586            num_layers: cfg.num_hidden_layers,
587            hidden_size: cfg.hidden_size,
588            num_kv_heads: cfg.num_key_value_heads,
589            num_attn_heads: cfg.num_attention_heads,
590            sliding_window: cfg.sliding_window,
591            k_head_dim: cfg.head_dim(),
592            v_head_dim: cfg.head_dim(),
593        };
594
595        Ok(Box::new(cfg))
596    }
597}
598
599// ======================== Gemma loader
600
601fn default_max_position_embeddings() -> usize {
602    4096
603}
604
605#[derive(Deserialize)]
606struct GemmaBasicConfig {
607    attention_bias: bool,
608    head_dim: usize,
609    // The code gemma configs include both hidden_act and hidden_activation.
610    hidden_act: Option<Activation>,
611    hidden_activation: Option<Activation>,
612    hidden_size: usize,
613    intermediate_size: usize,
614    num_attention_heads: usize,
615    num_hidden_layers: usize,
616    num_key_value_heads: usize,
617    rms_norm_eps: f64,
618    rope_theta: f64,
619    vocab_size: usize,
620
621    #[serde(default = "default_max_position_embeddings")]
622    max_position_embeddings: usize,
623    quantization_config: Option<QuantizedConfig>,
624    #[serde(default = "word_emb_default")]
625    tie_word_embeddings: bool,
626}
627
628impl GemmaBasicConfig {
629    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::gemma::Config> {
630        let basic_config: Self = serde_json::from_str(slice)?;
631        Ok(models::gemma::Config {
632            vocab_size: basic_config.vocab_size,
633            hidden_size: basic_config.hidden_size,
634            intermediate_size: basic_config.intermediate_size,
635            num_hidden_layers: basic_config.num_hidden_layers,
636            num_attention_heads: basic_config.num_attention_heads,
637            num_key_value_heads: basic_config.num_key_value_heads,
638            hidden_act: basic_config.hidden_act,
639            hidden_activation: basic_config.hidden_activation,
640            max_position_embeddings: basic_config.max_position_embeddings,
641            rms_norm_eps: basic_config.rms_norm_eps,
642            rope_theta: basic_config.rope_theta,
643            attention_bias: basic_config.attention_bias,
644            head_dim: basic_config.head_dim,
645            use_flash_attn,
646            quantization_config: basic_config.quantization_config,
647            tie_word_embeddings: basic_config.tie_word_embeddings,
648        })
649    }
650}
651
652/// [`NormalLoader`] for a Gemma model.
653///
654/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
655pub struct GemmaLoader;
656
657impl NormalModelLoader for GemmaLoader {
658    fn load(
659        &self,
660        config: &str,
661        use_flash_attn: bool,
662        vb: ShardedVarBuilder,
663        normal_loading_metadata: NormalLoadingMetadata,
664        attention_mechanism: AttentionImplementation,
665    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
666        Ok(Box::new(models::gemma::Model::new(
667            &GemmaBasicConfig::deserialize(config, use_flash_attn)?,
668            vb,
669            self.is_gptx(config)?,
670            normal_loading_metadata,
671            attention_mechanism,
672        )?))
673    }
674    fn load_xlora(
675        &self,
676        config: &str,
677        use_flash_attn: bool,
678        vb: ShardedVarBuilder,
679        lora_config: &[((String, String), LoraConfig)],
680        xlora_config: Option<XLoraConfig>,
681        xlora_ordering: Ordering,
682        normal_loading_metadata: NormalLoadingMetadata,
683        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
684    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
685        Ok(Box::new(xlora_models::XLoraGemma::new(
686            &GemmaBasicConfig::deserialize(config, use_flash_attn)?,
687            vb,
688            lora_config,
689            xlora_config,
690            xlora_ordering,
691            self.is_gptx(config)?,
692            normal_loading_metadata,
693            preload_adapters,
694        )?))
695    }
696    fn is_gptx(&self, _: &str) -> Result<bool> {
697        Ok(true)
698    }
699    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
700        Ok(Box::new(GemmaBasicConfig::deserialize(
701            config,
702            use_flash_attn,
703        )?))
704    }
705}
706
707impl IsqModelLoader for GemmaLoader {
708    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
709        Ok(vec![
710            Regex::new(r"lm_head\.(weight|bias)$")?,
711            // Attention
712            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
713            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
714            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
715            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
716            // MLP
717            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
718            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
719            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
720        ])
721    }
722}
723
724impl DeviceMappedModelLoader for GemmaLoader {
725    fn mapped_max_act_size_elems(
726        &self,
727        config: &str,
728        params: &AutoDeviceMapParams,
729        prompt_chunksize: usize,
730    ) -> Result<usize> {
731        let AutoDeviceMapParams::Text {
732            max_seq_len: _,
733            max_batch_size,
734        } = params
735        else {
736            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
737        };
738
739        let cfg = GemmaBasicConfig::deserialize(config, false)?;
740
741        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
742    }
743    fn non_mapped_max_act_size_elems(
744        &self,
745        _config: &str,
746        _params: &AutoDeviceMapParams,
747    ) -> Result<usize> {
748        Ok(0)
749    }
750
751    fn non_mapped_size_in_bytes(
752        &self,
753        config: &str,
754        dtype: DType,
755        weight_pack_factor: usize,
756    ) -> Result<usize> {
757        let cfg = GemmaBasicConfig::deserialize(config, false)?;
758        let elems = {
759            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
760            let lm_head = if !cfg.tie_word_embeddings {
761                cfg.hidden_size * cfg.vocab_size
762            } else {
763                0
764            };
765            let norm = cfg.hidden_size;
766            embed_tokens + lm_head + norm
767        };
768        Ok(elems * dtype.size_in_bytes())
769    }
770
771    fn layer_sizes_in_bytes(
772        &self,
773        config: &str,
774        dtype: DType,
775        weight_pack_factor: usize,
776    ) -> Result<Vec<usize>> {
777        let cfg = GemmaBasicConfig::deserialize(config, false)?;
778        let per_layer_elems = {
779            let input_layernorm = cfg.hidden_size;
780            let post_attention_layernorm = cfg.hidden_size;
781
782            let size_in = cfg.hidden_size;
783            let size_q = cfg.head_dim * cfg.num_attention_heads;
784            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
785            let q_proj =
786                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
787            let k_proj =
788                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
789            let v_proj =
790                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
791            let o_proj =
792                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
793
794            let h_size = cfg.hidden_size;
795            let i_size = cfg.intermediate_size;
796            let gate_proj = h_size * i_size / weight_pack_factor;
797            let up_proj = h_size * i_size / weight_pack_factor;
798            let down_proj = i_size * h_size / weight_pack_factor;
799
800            input_layernorm
801                + post_attention_layernorm
802                + q_proj
803                + k_proj
804                + v_proj
805                + o_proj
806                + gate_proj
807                + up_proj
808                + down_proj
809        };
810        Ok(vec![
811            per_layer_elems * dtype.size_in_bytes();
812            cfg.num_hidden_layers
813        ])
814    }
815
816    fn num_layers(&self, config: &str) -> Result<usize> {
817        let cfg = GemmaBasicConfig::deserialize(config, false)?;
818        Ok(cfg.num_hidden_layers)
819    }
820
821    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
822        let cfg = GemmaBasicConfig::deserialize(config, false)?;
823
824        let cfg = ModelConfigMetadata {
825            max_seq_len: cfg.max_position_embeddings,
826            num_layers: cfg.num_hidden_layers,
827            hidden_size: cfg.hidden_size,
828            num_kv_heads: cfg.num_key_value_heads,
829            num_attn_heads: cfg.num_attention_heads,
830            sliding_window: None,
831            k_head_dim: cfg.head_dim,
832            v_head_dim: cfg.head_dim,
833        };
834
835        Ok(Box::new(cfg))
836    }
837}
838
839// ======================== Llama loader
840
841#[derive(Deserialize)]
842struct LlamaBasicConfig {
843    hidden_act: Activation,
844    hidden_size: usize,
845    intermediate_size: usize,
846    vocab_size: usize,
847    num_hidden_layers: usize,
848    num_attention_heads: usize,
849    num_key_value_heads: Option<usize>,
850    rms_norm_eps: f64,
851    #[serde(default = "default_rope")]
852    rope_theta: f32,
853    max_position_embeddings: usize,
854    rope_scaling: Option<Llama3RopeConfig>,
855    quantization_config: Option<QuantizedConfig>,
856    #[serde(default = "word_emb_default")]
857    tie_word_embeddings: bool,
858}
859
860fn default_rope() -> f32 {
861    10_000.0
862}
863
864impl LlamaBasicConfig {
865    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::llama::Config> {
866        let basic_config: Self = serde_json::from_str(slice)?;
867        Ok(models::llama::Config {
868            hidden_size: basic_config.hidden_size,
869            intermediate_size: basic_config.intermediate_size,
870            vocab_size: basic_config.vocab_size,
871            num_hidden_layers: basic_config.num_hidden_layers,
872            num_attention_heads: basic_config.num_attention_heads,
873            num_key_value_heads: basic_config
874                .num_key_value_heads
875                .unwrap_or(basic_config.num_attention_heads),
876            rms_norm_eps: basic_config.rms_norm_eps,
877            rope_theta: basic_config.rope_theta,
878            use_flash_attn,
879            max_position_embeddings: basic_config.max_position_embeddings,
880            rope_scaling: basic_config.rope_scaling,
881            quantization_config: basic_config.quantization_config,
882            tie_word_embeddings: basic_config.tie_word_embeddings,
883            hidden_act: basic_config.hidden_act,
884        })
885    }
886}
887
888/// [`NormalLoader`] for a Llama model.
889///
890/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
891pub struct LlamaLoader;
892
893impl NormalModelLoader for LlamaLoader {
894    fn load(
895        &self,
896        config: &str,
897        use_flash_attn: bool,
898        vb: ShardedVarBuilder,
899        normal_loading_metadata: NormalLoadingMetadata,
900        attention_mechanism: AttentionImplementation,
901    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
902        Ok(Box::new(models::llama::Llama::new(
903            &LlamaBasicConfig::deserialize(config, use_flash_attn)?,
904            vb,
905            self.is_gptx(config)?,
906            normal_loading_metadata,
907            attention_mechanism,
908        )?))
909    }
910    fn load_xlora(
911        &self,
912        config: &str,
913        use_flash_attn: bool,
914        vb: ShardedVarBuilder,
915        lora_config: &[((String, String), LoraConfig)],
916        xlora_config: Option<XLoraConfig>,
917        xlora_ordering: Ordering,
918        normal_loading_metadata: NormalLoadingMetadata,
919        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
920    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
921        Ok(Box::new(xlora_models::XLoraLlama::new(
922            &LlamaBasicConfig::deserialize(config, use_flash_attn)?,
923            vb,
924            lora_config,
925            xlora_config,
926            xlora_ordering,
927            self.is_gptx(config)?,
928            normal_loading_metadata,
929            preload_adapters,
930        )?))
931    }
932    fn is_gptx(&self, _: &str) -> Result<bool> {
933        Ok(true)
934    }
935    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
936        Ok(Box::new(LlamaBasicConfig::deserialize(
937            config,
938            use_flash_attn,
939        )?))
940    }
941}
942
943impl IsqModelLoader for LlamaLoader {
944    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
945        Ok(vec![
946            Regex::new(r"lm_head\.(weight|bias)$")?,
947            // Attention
948            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
949            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
950            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
951            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
952            // MLP
953            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
954            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
955            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
956        ])
957    }
958}
959
960impl DeviceMappedModelLoader for LlamaLoader {
961    fn mapped_max_act_size_elems(
962        &self,
963        config: &str,
964        params: &AutoDeviceMapParams,
965        prompt_chunksize: usize,
966    ) -> Result<usize> {
967        let AutoDeviceMapParams::Text {
968            max_seq_len: _,
969            max_batch_size,
970        } = params
971        else {
972            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
973        };
974
975        let cfg = LlamaBasicConfig::deserialize(config, false)?;
976
977        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
978    }
979    fn non_mapped_max_act_size_elems(
980        &self,
981        _config: &str,
982        _params: &AutoDeviceMapParams,
983    ) -> Result<usize> {
984        Ok(0)
985    }
986
987    fn non_mapped_size_in_bytes(
988        &self,
989        config: &str,
990        dtype: DType,
991        weight_pack_factor: usize,
992    ) -> Result<usize> {
993        let cfg = LlamaBasicConfig::deserialize(config, false)?;
994        let elems = {
995            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
996            let lm_head = if !cfg.tie_word_embeddings {
997                cfg.hidden_size * cfg.vocab_size
998            } else {
999                0
1000            };
1001            let norm = cfg.hidden_size;
1002            embed_tokens + lm_head + norm
1003        };
1004        Ok(elems * dtype.size_in_bytes())
1005    }
1006
1007    fn layer_sizes_in_bytes(
1008        &self,
1009        config: &str,
1010        dtype: DType,
1011        weight_pack_factor: usize,
1012    ) -> Result<Vec<usize>> {
1013        let cfg = LlamaBasicConfig::deserialize(config, false)?;
1014        let per_layer_elems = {
1015            let input_layernorm = cfg.hidden_size;
1016            let post_attention_layernorm = cfg.hidden_size;
1017
1018            let size_in = cfg.hidden_size;
1019            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1020            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1021            let q_proj = size_in * size_q / weight_pack_factor;
1022            let k_proj = size_in * size_kv / weight_pack_factor;
1023            let v_proj = size_in * size_kv / weight_pack_factor;
1024            let o_proj = size_q * size_in / weight_pack_factor;
1025
1026            let h_size = cfg.hidden_size;
1027            let i_size = cfg.intermediate_size;
1028            let gate_proj = h_size * i_size / weight_pack_factor;
1029            let up_proj = h_size * i_size / weight_pack_factor;
1030            let down_proj = i_size * h_size / weight_pack_factor;
1031
1032            input_layernorm
1033                + post_attention_layernorm
1034                + q_proj
1035                + k_proj
1036                + v_proj
1037                + o_proj
1038                + gate_proj
1039                + up_proj
1040                + down_proj
1041        };
1042        Ok(vec![
1043            per_layer_elems * dtype.size_in_bytes();
1044            cfg.num_hidden_layers
1045        ])
1046    }
1047
1048    fn num_layers(&self, config: &str) -> Result<usize> {
1049        let cfg = LlamaBasicConfig::deserialize(config, false)?;
1050        Ok(cfg.num_hidden_layers)
1051    }
1052    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1053        let cfg = LlamaBasicConfig::deserialize(config, false)?;
1054
1055        let cfg = ModelConfigMetadata {
1056            max_seq_len: cfg.max_position_embeddings,
1057            num_layers: cfg.num_hidden_layers,
1058            hidden_size: cfg.hidden_size,
1059            num_kv_heads: cfg.num_key_value_heads,
1060            num_attn_heads: cfg.num_attention_heads,
1061            sliding_window: None,
1062            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1063            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1064        };
1065
1066        Ok(Box::new(cfg))
1067    }
1068}
1069
1070// ======================== Mixtral loader
1071
1072#[derive(Deserialize)]
1073struct MixtralBasicConfig {
1074    vocab_size: usize,
1075    hidden_size: usize,
1076    intermediate_size: usize,
1077    num_hidden_layers: usize,
1078    num_attention_heads: usize,
1079    num_key_value_heads: usize,
1080    hidden_act: Activation,
1081    max_position_embeddings: usize,
1082    rms_norm_eps: f64,
1083    rope_theta: f64,
1084    sliding_window: Option<usize>,
1085    num_experts_per_tok: usize,
1086    num_local_experts: usize,
1087    quantization_config: Option<QuantizedConfig>,
1088    #[serde(default = "word_emb_default")]
1089    tie_word_embeddings: bool,
1090}
1091
1092impl MixtralBasicConfig {
1093    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::mixtral::Config> {
1094        let basic_config: Self = serde_json::from_str(slice)?;
1095        Ok(models::mixtral::Config {
1096            vocab_size: basic_config.vocab_size,
1097            hidden_size: basic_config.hidden_size,
1098            intermediate_size: basic_config.intermediate_size,
1099            num_hidden_layers: basic_config.num_hidden_layers,
1100            num_attention_heads: basic_config.num_attention_heads,
1101            num_key_value_heads: basic_config.num_key_value_heads,
1102            hidden_act: basic_config.hidden_act,
1103            max_position_embeddings: basic_config.max_position_embeddings,
1104            rms_norm_eps: basic_config.rms_norm_eps,
1105            rope_theta: basic_config.rope_theta,
1106            sliding_window: basic_config.sliding_window,
1107            use_flash_attn,
1108            num_experts_per_tok: basic_config.num_experts_per_tok,
1109            num_local_experts: basic_config.num_local_experts,
1110            quantization_config: basic_config.quantization_config,
1111            tie_word_embeddings: basic_config.tie_word_embeddings,
1112        })
1113    }
1114}
1115
1116pub struct MixtralLoader;
1117
1118impl NormalModelLoader for MixtralLoader {
1119    fn load(
1120        &self,
1121        config: &str,
1122        use_flash_attn: bool,
1123        vb: ShardedVarBuilder,
1124        normal_loading_metadata: NormalLoadingMetadata,
1125        attention_mechanism: AttentionImplementation,
1126    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1127        Ok(Box::new(models::mixtral::Model::new(
1128            &MixtralBasicConfig::deserialize(config, use_flash_attn)?,
1129            vb,
1130            self.is_gptx(config)?,
1131            normal_loading_metadata,
1132            attention_mechanism,
1133        )?))
1134    }
1135    fn load_xlora(
1136        &self,
1137        config: &str,
1138        use_flash_attn: bool,
1139        vb: ShardedVarBuilder,
1140        lora_config: &[((String, String), LoraConfig)],
1141        xlora_config: Option<XLoraConfig>,
1142        xlora_ordering: Ordering,
1143        normal_loading_metadata: NormalLoadingMetadata,
1144        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1145    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1146        Ok(Box::new(xlora_models::XLoraMixtral::new(
1147            &MixtralBasicConfig::deserialize(config, use_flash_attn)?,
1148            vb,
1149            lora_config,
1150            xlora_config,
1151            xlora_ordering,
1152            self.is_gptx(config)?,
1153            normal_loading_metadata,
1154            preload_adapters,
1155        )?))
1156    }
1157    fn is_gptx(&self, _: &str) -> Result<bool> {
1158        Ok(true)
1159    }
1160    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1161        Ok(Box::new(MixtralBasicConfig::deserialize(
1162            config,
1163            use_flash_attn,
1164        )?))
1165    }
1166}
1167
1168impl IsqModelLoader for MixtralLoader {
1169    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1170        Ok(vec![
1171            Regex::new(r"lm_head\.(weight|bias)$")?,
1172            // Attention
1173            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1174            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1175            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1176            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1177            // Experts
1178            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1179            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1180            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1181            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1182        ])
1183    }
1184}
1185
1186impl DeviceMappedModelLoader for MixtralLoader {
1187    fn mapped_max_act_size_elems(
1188        &self,
1189        config: &str,
1190        params: &AutoDeviceMapParams,
1191        prompt_chunksize: usize,
1192    ) -> Result<usize> {
1193        let AutoDeviceMapParams::Text {
1194            max_seq_len: _,
1195            max_batch_size,
1196        } = params
1197        else {
1198            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1199        };
1200
1201        let cfg = MixtralBasicConfig::deserialize(config, false)?;
1202
1203        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1204    }
1205    fn non_mapped_max_act_size_elems(
1206        &self,
1207        _config: &str,
1208        _params: &AutoDeviceMapParams,
1209    ) -> Result<usize> {
1210        Ok(0)
1211    }
1212
1213    fn non_mapped_size_in_bytes(
1214        &self,
1215        config: &str,
1216        dtype: DType,
1217        weight_pack_factor: usize,
1218    ) -> Result<usize> {
1219        let cfg = MixtralBasicConfig::deserialize(config, false)?;
1220        let elems = {
1221            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1222            let lm_head = if !cfg.tie_word_embeddings {
1223                cfg.hidden_size * cfg.vocab_size
1224            } else {
1225                0
1226            };
1227            let norm = cfg.hidden_size;
1228            embed_tokens + lm_head + norm
1229        };
1230        Ok(elems * dtype.size_in_bytes())
1231    }
1232
1233    fn layer_sizes_in_bytes(
1234        &self,
1235        config: &str,
1236        dtype: DType,
1237        weight_pack_factor: usize,
1238    ) -> Result<Vec<usize>> {
1239        let cfg = MixtralBasicConfig::deserialize(config, false)?;
1240        let per_layer_elems = {
1241            let input_layernorm = cfg.hidden_size;
1242            let post_attention_layernorm = cfg.hidden_size;
1243
1244            let size_in = cfg.hidden_size;
1245            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1246            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1247            let q_proj = size_in * size_q / weight_pack_factor;
1248            let k_proj = size_in * size_kv / weight_pack_factor;
1249            let v_proj = size_in * size_kv / weight_pack_factor;
1250            let o_proj = size_q * size_in / weight_pack_factor;
1251
1252            let moe_block = {
1253                let gate = cfg.hidden_size * cfg.num_local_experts;
1254                // Assume quantizing weight pack factor
1255                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1256                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1257                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1258                gate + cfg.num_local_experts * w1
1259                    + cfg.num_local_experts * w2
1260                    + cfg.num_local_experts * w3
1261            };
1262
1263            input_layernorm
1264                + post_attention_layernorm
1265                + q_proj
1266                + k_proj
1267                + v_proj
1268                + o_proj
1269                + moe_block
1270        };
1271        Ok(vec![
1272            per_layer_elems * dtype.size_in_bytes();
1273            cfg.num_hidden_layers
1274        ])
1275    }
1276
1277    fn num_layers(&self, config: &str) -> Result<usize> {
1278        let cfg = MixtralBasicConfig::deserialize(config, false)?;
1279        Ok(cfg.num_hidden_layers)
1280    }
1281
1282    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1283        let cfg = MixtralBasicConfig::deserialize(config, false)?;
1284
1285        let cfg = ModelConfigMetadata {
1286            max_seq_len: cfg.max_position_embeddings,
1287            num_layers: cfg.num_hidden_layers,
1288            hidden_size: cfg.hidden_size,
1289            num_kv_heads: cfg.num_key_value_heads,
1290            num_attn_heads: cfg.num_attention_heads,
1291            sliding_window: cfg.sliding_window,
1292            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1293            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1294        };
1295
1296        Ok(Box::new(cfg))
1297    }
1298}
1299
1300// ======================== Phi2 loader
1301
1302#[derive(Deserialize)]
1303struct Phi2BasicConfig {
1304    vocab_size: usize,
1305    hidden_size: usize,
1306    intermediate_size: usize,
1307    num_hidden_layers: usize,
1308    num_attention_heads: usize,
1309    num_key_value_heads: Option<usize>,
1310    hidden_act: Activation,
1311    max_position_embeddings: usize,
1312    layer_norm_eps: f64,
1313    rope_theta: f32,
1314    partial_rotary_factor: f64,
1315    qk_layernorm: bool,
1316    quantization_config: Option<QuantizedConfig>,
1317    #[serde(default = "word_emb_default")]
1318    tie_word_embeddings: bool,
1319}
1320
1321impl Phi2BasicConfig {
1322    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi2::Config> {
1323        let basic_config: Self = serde_json::from_str(slice)?;
1324        Ok(models::phi2::Config {
1325            vocab_size: basic_config.vocab_size,
1326            hidden_size: basic_config.hidden_size,
1327            intermediate_size: basic_config.intermediate_size,
1328            num_hidden_layers: basic_config.num_hidden_layers,
1329            num_attention_heads: basic_config.num_attention_heads,
1330            num_key_value_heads: basic_config.num_key_value_heads,
1331            hidden_act: basic_config.hidden_act,
1332            max_position_embeddings: basic_config.max_position_embeddings,
1333            rope_theta: basic_config.rope_theta,
1334            layer_norm_eps: basic_config.layer_norm_eps,
1335            partial_rotary_factor: basic_config.partial_rotary_factor,
1336            qk_layernorm: basic_config.qk_layernorm,
1337            use_flash_attn,
1338            quantization_config: basic_config.quantization_config,
1339            tie_word_embeddings: basic_config.tie_word_embeddings,
1340        })
1341    }
1342}
1343
1344/// [`NormalLoader`] for a Phi 2 model.
1345///
1346/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1347pub struct Phi2Loader;
1348
1349impl NormalModelLoader for Phi2Loader {
1350    fn load(
1351        &self,
1352        config: &str,
1353        use_flash_attn: bool,
1354        vb: ShardedVarBuilder,
1355        normal_loading_metadata: NormalLoadingMetadata,
1356        attention_mechanism: AttentionImplementation,
1357    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1358        Ok(Box::new(models::phi2::Model::new(
1359            &Phi2BasicConfig::deserialize(config, use_flash_attn)?,
1360            vb,
1361            self.is_gptx(config)?,
1362            normal_loading_metadata,
1363            attention_mechanism,
1364        )?))
1365    }
1366    fn load_xlora(
1367        &self,
1368        config: &str,
1369        use_flash_attn: bool,
1370        vb: ShardedVarBuilder,
1371        lora_config: &[((String, String), LoraConfig)],
1372        xlora_config: Option<XLoraConfig>,
1373        xlora_ordering: Ordering,
1374        normal_loading_metadata: NormalLoadingMetadata,
1375        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1376    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1377        Ok(Box::new(xlora_models::XLoraPhi2::new(
1378            &Phi2BasicConfig::deserialize(config, use_flash_attn)?,
1379            vb,
1380            lora_config,
1381            xlora_config,
1382            xlora_ordering,
1383            self.is_gptx(config)?,
1384            normal_loading_metadata,
1385            preload_adapters,
1386        )?))
1387    }
1388    fn is_gptx(&self, _: &str) -> Result<bool> {
1389        Ok(true)
1390    }
1391    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1392        Ok(Box::new(Phi2BasicConfig::deserialize(
1393            config,
1394            use_flash_attn,
1395        )?))
1396    }
1397}
1398
1399impl IsqModelLoader for Phi2Loader {
1400    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1401        Ok(vec![
1402            Regex::new(r"lm_head\.(weight|bias)$")?,
1403            // Attention
1404            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1405            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1406            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1407            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1408            // MLP
1409            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1410            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1411        ])
1412    }
1413}
1414
1415impl DeviceMappedModelLoader for Phi2Loader {
1416    fn mapped_max_act_size_elems(
1417        &self,
1418        config: &str,
1419        params: &AutoDeviceMapParams,
1420        prompt_chunksize: usize,
1421    ) -> Result<usize> {
1422        let AutoDeviceMapParams::Text {
1423            max_seq_len: _,
1424            max_batch_size,
1425        } = params
1426        else {
1427            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1428        };
1429
1430        let cfg = Phi2BasicConfig::deserialize(config, false)?;
1431
1432        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1433    }
1434    fn non_mapped_max_act_size_elems(
1435        &self,
1436        _config: &str,
1437        _params: &AutoDeviceMapParams,
1438    ) -> Result<usize> {
1439        Ok(0)
1440    }
1441
1442    fn non_mapped_size_in_bytes(
1443        &self,
1444        config: &str,
1445        dtype: DType,
1446        weight_pack_factor: usize,
1447    ) -> Result<usize> {
1448        let cfg = Phi2BasicConfig::deserialize(config, false)?;
1449        let elems = {
1450            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1451            let lm_head = if !cfg.tie_word_embeddings {
1452                cfg.hidden_size * cfg.vocab_size
1453            } else {
1454                0
1455            };
1456            let norm = cfg.hidden_size;
1457            embed_tokens + lm_head + norm
1458        };
1459        Ok(elems * dtype.size_in_bytes())
1460    }
1461
1462    fn layer_sizes_in_bytes(
1463        &self,
1464        config: &str,
1465        dtype: DType,
1466        weight_pack_factor: usize,
1467    ) -> Result<Vec<usize>> {
1468        let cfg = Phi2BasicConfig::deserialize(config, false)?;
1469        let per_layer_elems = {
1470            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1471
1472            let size_in = cfg.hidden_size;
1473            let size_q = cfg.head_dim() * cfg.num_attention_heads;
1474            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1475            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1476            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1477            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1478            let o_proj = size_q * size_in / weight_pack_factor + size_in;
1479            let (q_norm, k_norm) = if cfg.qk_layernorm {
1480                (cfg.head_dim(), cfg.head_dim())
1481            } else {
1482                (0, 0)
1483            };
1484
1485            let h_size = cfg.hidden_size;
1486            let i_size = cfg.intermediate_size;
1487            let fc1 = h_size * i_size / weight_pack_factor;
1488            let fc2 = h_size * i_size / weight_pack_factor;
1489
1490            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1491        };
1492        Ok(vec![
1493            per_layer_elems * dtype.size_in_bytes();
1494            cfg.num_hidden_layers
1495        ])
1496    }
1497
1498    fn num_layers(&self, config: &str) -> Result<usize> {
1499        let cfg = Phi2BasicConfig::deserialize(config, false)?;
1500        Ok(cfg.num_hidden_layers)
1501    }
1502
1503    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1504        let cfg = Phi2BasicConfig::deserialize(config, false)?;
1505
1506        let cfg = ModelConfigMetadata {
1507            max_seq_len: cfg.max_position_embeddings,
1508            num_layers: cfg.num_hidden_layers,
1509            hidden_size: cfg.hidden_size,
1510            num_kv_heads: cfg.num_key_value_heads(),
1511            num_attn_heads: cfg.num_attention_heads,
1512            sliding_window: None,
1513            k_head_dim: cfg.head_dim(),
1514            v_head_dim: cfg.head_dim(),
1515        };
1516
1517        Ok(Box::new(cfg))
1518    }
1519}
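// A minimal sketch (hypothetical helper, not part of the mistral.rs API) of how the
// per-layer weight estimate and the peak activation estimate above could be combined
// when budgeting memory for one device-mapped layer.
#[allow(dead_code)]
fn approx_layer_budget_bytes(
    layer_weights_bytes: usize,
    max_act_elems: usize,
    act_dtype: DType,
) -> usize {
    // Weights for the layer plus one activation buffer at the activation dtype.
    layer_weights_bytes + max_act_elems * act_dtype.size_in_bytes()
}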
1520
1521// ======================== Phi3 loader
1522
1523#[derive(Deserialize)]
1524struct Phi3BasicConfig {
1525    vocab_size: usize,
1526    hidden_act: Activation,
1527    hidden_size: usize,
1528    intermediate_size: usize,
1529    num_hidden_layers: usize,
1530    num_attention_heads: usize,
1531    num_key_value_heads: usize,
1532    rms_norm_eps: f64,
1533    rope_theta: f64,
1534    bos_token_id: Option<u32>,
1535    eos_token_id: Option<u32>,
1536    rope_scaling: Option<PhiRopeScalingConfig>,
1537    max_position_embeddings: usize,
1538    original_max_position_embeddings: usize,
1539    sliding_window: Option<usize>,
1540    quantization_config: Option<QuantizedConfig>,
1541    #[serde(default = "word_emb_default")]
1542    tie_word_embeddings: bool,
1543    partial_rotary_factor: Option<f64>,
1544}
1545
1546impl Phi3BasicConfig {
1547    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi3::Config> {
1548        let basic_config: Self = serde_json::from_str(slice)?;
1549        Ok(models::phi3::Config {
1550            vocab_size: basic_config.vocab_size,
1551            hidden_size: basic_config.hidden_size,
1552            intermediate_size: basic_config.intermediate_size,
1553            num_hidden_layers: basic_config.num_hidden_layers,
1554            num_attention_heads: basic_config.num_attention_heads,
1555            num_key_value_heads: basic_config.num_key_value_heads,
1556            hidden_act: basic_config.hidden_act,
1557            max_position_embeddings: basic_config.max_position_embeddings,
1558            rope_theta: basic_config.rope_theta,
1559            rms_norm_eps: basic_config.rms_norm_eps,
1560            eos_token_id: basic_config.eos_token_id,
1561            bos_token_id: basic_config.bos_token_id,
1562            rope_scaling: basic_config.rope_scaling,
1563            original_max_position_embeddings: basic_config.original_max_position_embeddings,
1564            use_flash_attn,
1565            sliding_window: basic_config.sliding_window,
1566            quantization_config: basic_config.quantization_config,
1567            tie_word_embeddings: basic_config.tie_word_embeddings,
1568            partial_rotary_factor: basic_config.partial_rotary_factor,
1569        })
1570    }
1571}
1572
1573/// [`NormalLoader`] for a Phi 3 model.
1574///
1575/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1576pub struct Phi3Loader;
1577
1578impl NormalModelLoader for Phi3Loader {
1579    fn load(
1580        &self,
1581        config: &str,
1582        use_flash_attn: bool,
1583        vb: ShardedVarBuilder,
1584        normal_loading_metadata: NormalLoadingMetadata,
1585        attention_mechanism: AttentionImplementation,
1586    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1587        Ok(Box::new(models::phi3::Model::new(
1588            &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
1589            vb,
1590            self.is_gptx(config)?,
1591            normal_loading_metadata,
1592            attention_mechanism,
1593        )?))
1594    }
1595    fn load_xlora(
1596        &self,
1597        config: &str,
1598        use_flash_attn: bool,
1599        vb: ShardedVarBuilder,
1600        lora_config: &[((String, String), LoraConfig)],
1601        xlora_config: Option<XLoraConfig>,
1602        xlora_ordering: Ordering,
1603        normal_loading_metadata: NormalLoadingMetadata,
1604        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1605    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1606        Ok(Box::new(xlora_models::XLoraPhi3::new(
1607            &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
1608            vb,
1609            lora_config,
1610            xlora_config,
1611            xlora_ordering,
1612            self.is_gptx(config)?,
1613            normal_loading_metadata,
1614            preload_adapters,
1615        )?))
1616    }
1617    fn is_gptx(&self, _: &str) -> Result<bool> {
1618        Ok(true)
1619    }
1620    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1621        Ok(Box::new(Phi3BasicConfig::deserialize(
1622            config,
1623            use_flash_attn,
1624        )?))
1625    }
1626}
1627
1628impl IsqModelLoader for Phi3Loader {
1629    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1630        Ok(vec![
1631            Regex::new(r"lm_head\.(weight|bias)$")?,
1632            // Attention
1633            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1634            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1635            // MLP
1636            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1637            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1638            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1639        ])
1640    }
1641}
1642
1643impl DeviceMappedModelLoader for Phi3Loader {
1644    fn mapped_max_act_size_elems(
1645        &self,
1646        config: &str,
1647        params: &AutoDeviceMapParams,
1648        prompt_chunksize: usize,
1649    ) -> Result<usize> {
1650        let AutoDeviceMapParams::Text {
1651            max_seq_len: _,
1652            max_batch_size,
1653        } = params
1654        else {
1655            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1656        };
1657
1658        let cfg = Phi3BasicConfig::deserialize(config, false)?;
1659
1660        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1661    }
1662    fn non_mapped_max_act_size_elems(
1663        &self,
1664        _config: &str,
1665        _params: &AutoDeviceMapParams,
1666    ) -> Result<usize> {
1667        Ok(0)
1668    }
1669
1670    fn non_mapped_size_in_bytes(
1671        &self,
1672        config: &str,
1673        dtype: DType,
1674        weight_pack_factor: usize,
1675    ) -> Result<usize> {
1676        let cfg = Phi3BasicConfig::deserialize(config, false)?;
1677        let elems = {
1678            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1679            let lm_head = if !cfg.tie_word_embeddings {
1680                cfg.hidden_size * cfg.vocab_size
1681            } else {
1682                0
1683            };
1684            let norm = cfg.hidden_size;
1685            embed_tokens + lm_head + norm
1686        };
1687        Ok(elems * dtype.size_in_bytes())
1688    }
1689
1690    fn layer_sizes_in_bytes(
1691        &self,
1692        config: &str,
1693        dtype: DType,
1694        weight_pack_factor: usize,
1695    ) -> Result<Vec<usize>> {
1696        let cfg = Phi3BasicConfig::deserialize(config, false)?;
1697        let per_layer_elems = {
1698            let input_layernorm = cfg.hidden_size;
1699            let post_attention_layernorm = cfg.hidden_size;
1700
1701            let size_in = cfg.hidden_size;
1702            let head_dim = cfg.head_dim();
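            // op_size is the output width of the fused qkv_proj: one head_dim-wide
            // slice per attention head plus two per key/value head.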
1703            let op_size = cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1704            let qkv_proj = size_in * op_size / weight_pack_factor;
1705            let o_proj =
1706                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1707
1708            let h_size = cfg.hidden_size;
1709            let i_size = cfg.intermediate_size;
1710            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1711            let down_proj = h_size * i_size / weight_pack_factor;
1712
1713            input_layernorm
1714                + post_attention_layernorm
1715                + qkv_proj
1716                + o_proj
1717                + gate_up_proj
1718                + down_proj
1719        };
1720        Ok(vec![
1721            per_layer_elems * dtype.size_in_bytes();
1722            cfg.num_hidden_layers
1723        ])
1724    }
1725
1726    fn num_layers(&self, config: &str) -> Result<usize> {
1727        let cfg = Phi3BasicConfig::deserialize(config, false)?;
1728        Ok(cfg.num_hidden_layers)
1729    }
1730
1731    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1732        let cfg = Phi3BasicConfig::deserialize(config, false)?;
1733
1734        let cfg = ModelConfigMetadata {
1735            max_seq_len: cfg.max_position_embeddings,
1736            num_layers: cfg.num_hidden_layers,
1737            hidden_size: cfg.hidden_size,
1738            num_kv_heads: cfg.num_key_value_heads,
1739            num_attn_heads: cfg.num_attention_heads,
1740            sliding_window: cfg.sliding_window,
1741            k_head_dim: cfg.head_dim(),
1742            v_head_dim: cfg.head_dim(),
1743        };
1744
1745        Ok(Box::new(cfg))
1746    }
1747}
1748
1749// ======================== Qwen2 loader
1750
1751#[derive(Deserialize)]
1752struct Qwen2BasicConfig {
1753    vocab_size: usize,
1754    hidden_size: usize,
1755    intermediate_size: usize,
1756    num_hidden_layers: usize,
1757    num_attention_heads: usize,
1758    num_key_value_heads: usize,
1759    max_position_embeddings: usize,
1760    sliding_window: usize,
1761    rope_theta: f64,
1762    rms_norm_eps: f64,
1763    hidden_act: Activation,
1764    quantization_config: Option<QuantizedConfig>,
1765    tie_word_embeddings: bool,
1766}
1767
1768impl Qwen2BasicConfig {
1769    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::qwen2::Config> {
1770        let basic_config: Self = serde_json::from_str(slice)?;
1771        Ok(models::qwen2::Config {
1772            vocab_size: basic_config.vocab_size,
1773            hidden_size: basic_config.hidden_size,
1774            intermediate_size: basic_config.intermediate_size,
1775            num_hidden_layers: basic_config.num_hidden_layers,
1776            num_attention_heads: basic_config.num_attention_heads,
1777            num_key_value_heads: basic_config.num_key_value_heads,
1778            hidden_act: basic_config.hidden_act,
1779            max_position_embeddings: basic_config.max_position_embeddings,
1780            rope_theta: basic_config.rope_theta,
1781            rms_norm_eps: basic_config.rms_norm_eps,
1782            sliding_window: basic_config.sliding_window,
1783            use_flash_attn,
1784            quantization_config: basic_config.quantization_config,
1785            tie_word_embeddings: basic_config.tie_word_embeddings,
1786        })
1787    }
1788}
1789
1790/// [`NormalLoader`] for a Qwen 2 model.
1791///
1792/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
1793pub struct Qwen2Loader;
1794
1795impl NormalModelLoader for Qwen2Loader {
1796    fn load(
1797        &self,
1798        config: &str,
1799        use_flash_attn: bool,
1800        vb: ShardedVarBuilder,
1801        normal_loading_metadata: NormalLoadingMetadata,
1802        attention_mechanism: AttentionImplementation,
1803    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1804        Ok(Box::new(models::qwen2::Model::new(
1805            &Qwen2BasicConfig::deserialize(config, use_flash_attn)?,
1806            vb,
1807            self.is_gptx(config)?,
1808            normal_loading_metadata,
1809            attention_mechanism,
1810        )?))
1811    }
1812    fn load_xlora(
1813        &self,
1814        _config: &str,
1815        _use_flash_attn: bool,
1816        _vb: ShardedVarBuilder,
1817        _lora_config: &[((String, String), LoraConfig)],
1818        _xlora_config: Option<XLoraConfig>,
1819        _xlora_ordering: Ordering,
1820        _normal_loading_metadata: NormalLoadingMetadata,
1821        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1822    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
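        // X-LoRA is not implemented for Qwen2; calling this panics via `todo!()`.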
1823        todo!()
1824    }
1825    fn is_gptx(&self, _: &str) -> Result<bool> {
1826        Ok(true)
1827    }
1828    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1829        Ok(Box::new(Qwen2BasicConfig::deserialize(
1830            config,
1831            use_flash_attn,
1832        )?))
1833    }
1834}
1835
1836impl IsqModelLoader for Qwen2Loader {
1837    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1838        Ok(vec![
1839            Regex::new(r"lm_head\.(weight|bias)$")?,
1840            // Attention
1841            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1842            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1843            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1844            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1845            // MLP
1846            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1847            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1848            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1849        ])
1850    }
1851}
1852
1853impl DeviceMappedModelLoader for Qwen2Loader {
1854    fn mapped_max_act_size_elems(
1855        &self,
1856        config: &str,
1857        params: &AutoDeviceMapParams,
1858        prompt_chunksize: usize,
1859    ) -> Result<usize> {
1860        let AutoDeviceMapParams::Text {
1861            max_seq_len: _,
1862            max_batch_size,
1863        } = params
1864        else {
1865            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1866        };
1867
1868        let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1869
1870        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1871    }
1872    fn non_mapped_max_act_size_elems(
1873        &self,
1874        _config: &str,
1875        _params: &AutoDeviceMapParams,
1876    ) -> Result<usize> {
1877        Ok(0)
1878    }
1879
1880    fn non_mapped_size_in_bytes(
1881        &self,
1882        config: &str,
1883        dtype: DType,
1884        weight_pack_factor: usize,
1885    ) -> Result<usize> {
1886        let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1887        let elems = {
1888            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1889            let lm_head = if !cfg.tie_word_embeddings {
1890                cfg.hidden_size * cfg.vocab_size
1891            } else {
1892                0
1893            };
1894            let norm = cfg.hidden_size;
1895            embed_tokens + lm_head + norm
1896        };
1897        Ok(elems * dtype.size_in_bytes())
1898    }
1899
1900    fn layer_sizes_in_bytes(
1901        &self,
1902        config: &str,
1903        dtype: DType,
1904        weight_pack_factor: usize,
1905    ) -> Result<Vec<usize>> {
1906        let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1907        let per_layer_elems = {
1908            let input_layernorm = cfg.hidden_size;
1909            let post_attention_layernorm = cfg.hidden_size;
1910
1911            let size_in = cfg.hidden_size;
1912            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1913            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
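            // Qwen2's q/k/v projections carry biases (the `+ size_*` terms); o_proj does not.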
1914            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1915            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1916            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1917            let o_proj = size_q * size_in / weight_pack_factor;
1918
1919            let h_size = cfg.hidden_size;
1920            let i_size = cfg.intermediate_size;
1921            let gate_proj = h_size * i_size / weight_pack_factor;
1922            let up_proj = h_size * i_size / weight_pack_factor;
1923            let down_proj = i_size * h_size / weight_pack_factor;
1924
1925            input_layernorm
1926                + post_attention_layernorm
1927                + q_proj
1928                + k_proj
1929                + v_proj
1930                + o_proj
1931                + gate_proj
1932                + up_proj
1933                + down_proj
1934        };
1935        Ok(vec![
1936            per_layer_elems * dtype.size_in_bytes();
1937            cfg.num_hidden_layers
1938        ])
1939    }
1940
1941    fn num_layers(&self, config: &str) -> Result<usize> {
1942        let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1943        Ok(cfg.num_hidden_layers)
1944    }
1945
1946    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1947        let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1948
1949        let cfg = ModelConfigMetadata {
1950            max_seq_len: cfg.max_position_embeddings,
1951            num_layers: cfg.num_hidden_layers,
1952            hidden_size: cfg.hidden_size,
1953            num_kv_heads: cfg.num_key_value_heads,
1954            num_attn_heads: cfg.num_attention_heads,
1955            sliding_window: Some(cfg.sliding_window),
1956            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1957            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1958        };
1959
1960        Ok(Box::new(cfg))
1961    }
1962}
1963
1964// ======================== Gemma2 loader
1965
1966#[derive(Deserialize)]
1967struct Gemma2BasicConfig {
1968    attention_bias: bool,
1969    head_dim: usize,
1970    // The code gemma configs include both hidden_act and hidden_activation.
1971    hidden_act: Option<Activation>,
1972    hidden_activation: Option<Activation>,
1973    hidden_size: usize,
1974    intermediate_size: usize,
1975    num_attention_heads: usize,
1976    num_hidden_layers: usize,
1977    num_key_value_heads: usize,
1978    rms_norm_eps: f64,
1979    rope_theta: f64,
1980    vocab_size: usize,
1981    sliding_window: usize,
1982    attn_logit_softcapping: Option<f64>,
1983    final_logit_softcapping: Option<f64>,
1984    query_pre_attn_scalar: usize,
1985
1986    #[serde(default = "default_max_position_embeddings")]
1987    max_position_embeddings: usize,
1988    quantization_config: Option<QuantizedConfig>,
1989    #[serde(default = "word_emb_default")]
1990    tie_word_embeddings: bool,
1991}
1992
1993impl Gemma2BasicConfig {
1994    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::gemma2::Config> {
1995        let basic_config: Self = serde_json::from_str(slice)?;
1996        Ok(models::gemma2::Config {
1997            vocab_size: basic_config.vocab_size,
1998            hidden_size: basic_config.hidden_size,
1999            intermediate_size: basic_config.intermediate_size,
2000            num_hidden_layers: basic_config.num_hidden_layers,
2001            num_attention_heads: basic_config.num_attention_heads,
2002            num_key_value_heads: basic_config.num_key_value_heads,
2003            hidden_act: basic_config.hidden_act,
2004            hidden_activation: basic_config.hidden_activation,
2005            max_position_embeddings: basic_config.max_position_embeddings,
2006            rms_norm_eps: basic_config.rms_norm_eps,
2007            rope_theta: basic_config.rope_theta,
2008            attention_bias: basic_config.attention_bias,
2009            head_dim: basic_config.head_dim,
2010            use_flash_attn,
2011            quantization_config: basic_config.quantization_config,
2012            sliding_window: basic_config.sliding_window,
2013            attn_logit_softcapping: basic_config.attn_logit_softcapping,
2014            final_logit_softcapping: basic_config.final_logit_softcapping,
2015            query_pre_attn_scalar: basic_config.query_pre_attn_scalar,
2016            tie_word_embeddings: basic_config.tie_word_embeddings,
2017        })
2018    }
2019}
2020
2021/// [`NormalLoader`] for a Gemma2 model.
2022///
2023/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2024pub struct Gemma2Loader;
2025
2026impl NormalModelLoader for Gemma2Loader {
2027    fn load(
2028        &self,
2029        config: &str,
2030        use_flash_attn: bool,
2031        vb: ShardedVarBuilder,
2032        normal_loading_metadata: NormalLoadingMetadata,
2033        attention_mechanism: AttentionImplementation,
2034    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2035        Ok(Box::new(models::gemma2::Model::new(
2036            &Gemma2BasicConfig::deserialize(config, use_flash_attn)?,
2037            vb,
2038            self.is_gptx(config)?,
2039            normal_loading_metadata,
2040            attention_mechanism,
2041        )?))
2042    }
2043    fn load_xlora(
2044        &self,
2045        config: &str,
2046        use_flash_attn: bool,
2047        vb: ShardedVarBuilder,
2048        lora_config: &[((String, String), LoraConfig)],
2049        xlora_config: Option<XLoraConfig>,
2050        xlora_ordering: Ordering,
2051        normal_loading_metadata: NormalLoadingMetadata,
2052        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2053    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2054        Ok(Box::new(xlora_models::XLoraGemma2::new(
2055            &Gemma2BasicConfig::deserialize(config, use_flash_attn)?,
2056            vb,
2057            lora_config,
2058            xlora_config,
2059            xlora_ordering,
2060            self.is_gptx(config)?,
2061            normal_loading_metadata,
2062            preload_adapters,
2063        )?))
2064    }
2065    fn is_gptx(&self, _: &str) -> Result<bool> {
2066        Ok(true)
2067    }
2068    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2069        // Already will warn about it
2070        Ok(Box::new(Gemma2BasicConfig::deserialize(
2071            config,
2072            use_flash_attn,
2073        )?))
2074    }
2075}
2076
2077impl IsqModelLoader for Gemma2Loader {
2078    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2079        Ok(vec![
2080            Regex::new(r"lm_head\.(weight|bias)$")?,
2081            // Attention
2082            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2083            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2084            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2085            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2086            // MLP
2087            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2088            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2089            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2090        ])
2091    }
2092}
2093
2094impl DeviceMappedModelLoader for Gemma2Loader {
2095    fn mapped_max_act_size_elems(
2096        &self,
2097        config: &str,
2098        params: &AutoDeviceMapParams,
2099        prompt_chunksize: usize,
2100    ) -> Result<usize> {
2101        let AutoDeviceMapParams::Text {
2102            max_seq_len: _,
2103            max_batch_size,
2104        } = params
2105        else {
2106            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2107        };
2108
2109        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2110
2111        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2112    }
2113    fn non_mapped_max_act_size_elems(
2114        &self,
2115        _config: &str,
2116        _params: &AutoDeviceMapParams,
2117    ) -> Result<usize> {
2118        Ok(0)
2119    }
2120
2121    fn non_mapped_size_in_bytes(
2122        &self,
2123        config: &str,
2124        dtype: DType,
2125        weight_pack_factor: usize,
2126    ) -> Result<usize> {
2127        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2128        let elems = {
2129            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2130            let lm_head = if !cfg.tie_word_embeddings {
2131                cfg.hidden_size * cfg.vocab_size
2132            } else {
2133                0
2134            };
2135            let norm = cfg.hidden_size;
2136            embed_tokens + lm_head + norm
2137        };
2138        Ok(elems * dtype.size_in_bytes())
2139    }
2140
2141    fn layer_sizes_in_bytes(
2142        &self,
2143        config: &str,
2144        dtype: DType,
2145        weight_pack_factor: usize,
2146    ) -> Result<Vec<usize>> {
2147        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2148        let per_layer_elems = {
2149            let input_layernorm = cfg.hidden_size;
2150            let post_attention_layernorm = cfg.hidden_size;
2151
2152            let size_in = cfg.hidden_size;
2153            let size_q = cfg.head_dim * cfg.num_attention_heads;
2154            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
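            // `bias_if!` only counts the bias elements when `attention_bias` is set.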
2155            let q_proj =
2156                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2157            let k_proj =
2158                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2159            let v_proj =
2160                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2161            let o_proj =
2162                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2163
2164            let h_size = cfg.hidden_size;
2165            let i_size = cfg.intermediate_size;
2166            let gate_proj = h_size * i_size / weight_pack_factor;
2167            let up_proj = h_size * i_size / weight_pack_factor;
2168            let down_proj = i_size * h_size / weight_pack_factor;
2169
2170            input_layernorm
2171                + post_attention_layernorm
2172                + q_proj
2173                + k_proj
2174                + v_proj
2175                + o_proj
2176                + gate_proj
2177                + up_proj
2178                + down_proj
2179        };
2180        Ok(vec![
2181            per_layer_elems * dtype.size_in_bytes();
2182            cfg.num_hidden_layers
2183        ])
2184    }
2185
2186    fn num_layers(&self, config: &str) -> Result<usize> {
2187        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2188        Ok(cfg.num_hidden_layers)
2189    }
2190    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2191        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2192
2193        let cfg = ModelConfigMetadata {
2194            max_seq_len: cfg.max_position_embeddings,
2195            num_layers: cfg.num_hidden_layers,
2196            hidden_size: cfg.hidden_size,
2197            num_kv_heads: cfg.num_key_value_heads,
2198            num_attn_heads: cfg.num_attention_heads,
2199            sliding_window: None, // None to be more forgiving; some layers do not use sliding-window attention
2200            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2201            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2202        };
2203
2204        Ok(Box::new(cfg))
2205    }
2206}
2207
2208// ======================== Starcoder2 loader
2209
2210#[derive(Deserialize, Debug)]
2211struct Starcoder2BasicConfig {
2212    vocab_size: usize,
2213    hidden_size: usize,
2214    intermediate_size: usize,
2215    num_hidden_layers: usize,
2216    num_attention_heads: usize,
2217    num_key_value_heads: usize,
2218    hidden_act: Activation,
2219    max_position_embeddings: usize,
2220    norm_epsilon: f64,
2221    rope_theta: f64,
2222    use_bias: bool,
2223    sliding_window: Option<usize>,
2224    quantization_config: Option<QuantizedConfig>,
2225    #[serde(default = "word_emb_default")]
2226    tie_word_embeddings: bool,
2227}
2228
2229impl Starcoder2BasicConfig {
2230    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::starcoder2::Config> {
2231        let basic_config: Self = serde_json::from_str(slice)?;
2232        Ok(models::starcoder2::Config {
2233            vocab_size: basic_config.vocab_size,
2234            hidden_size: basic_config.hidden_size,
2235            intermediate_size: basic_config.intermediate_size,
2236            num_hidden_layers: basic_config.num_hidden_layers,
2237            num_attention_heads: basic_config.num_attention_heads,
2238            num_key_value_heads: basic_config.num_key_value_heads,
2239            hidden_act: basic_config.hidden_act,
2240            max_position_embeddings: basic_config.max_position_embeddings,
2241            rope_theta: basic_config.rope_theta,
2242            sliding_window: basic_config.sliding_window,
2243            use_flash_attn,
2244            norm_epsilon: basic_config.norm_epsilon,
2245            use_bias: basic_config.use_bias,
2246            quantization_config: basic_config.quantization_config,
2247            tie_word_embeddings: basic_config.tie_word_embeddings,
2248        })
2249    }
2250}
2251
2252/// [`NormalLoader`] for a Starcoder2 model.
2253///
2254/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2255pub struct Starcoder2Loader;
2256
2257impl NormalModelLoader for Starcoder2Loader {
2258    fn load(
2259        &self,
2260        config: &str,
2261        use_flash_attn: bool,
2262        vb: ShardedVarBuilder,
2263        normal_loading_metadata: NormalLoadingMetadata,
2264        attention_mechanism: AttentionImplementation,
2265    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2266        Ok(Box::new(models::starcoder2::Model::new(
2267            &Starcoder2BasicConfig::deserialize(config, use_flash_attn)?,
2268            vb,
2269            self.is_gptx(config)?,
2270            normal_loading_metadata,
2271            attention_mechanism,
2272        )?))
2273    }
2274    fn load_xlora(
2275        &self,
2276        config: &str,
2277        use_flash_attn: bool,
2278        vb: ShardedVarBuilder,
2279        lora_config: &[((String, String), LoraConfig)],
2280        xlora_config: Option<XLoraConfig>,
2281        xlora_ordering: Ordering,
2282        normal_loading_metadata: NormalLoadingMetadata,
2283        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2284    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2285        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2286            &Starcoder2BasicConfig::deserialize(config, use_flash_attn)?,
2287            vb,
2288            lora_config,
2289            xlora_config,
2290            xlora_ordering,
2291            self.is_gptx(config)?,
2292            normal_loading_metadata,
2293            preload_adapters,
2294        )?))
2295    }
2296    fn is_gptx(&self, _: &str) -> Result<bool> {
2297        Ok(true)
2298    }
2299    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2300        Ok(Box::new(serde_json::from_str::<Starcoder2BasicConfig>(
2301            config,
2302        )?))
2303    }
2304}
2305
2306impl IsqModelLoader for Starcoder2Loader {
2307    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2308        Ok(vec![
2309            Regex::new(r"lm_head\.(weight|bias)$")?,
2310            // Attention
2311            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2312            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2313            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2314            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2315            // MLP
2316            Regex::new(r"layers\.(\d+)\.mlp\.c_fc\.(weight|bias)$")?,
2317            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2318        ])
2319    }
2320}
2321
2322impl DeviceMappedModelLoader for Starcoder2Loader {
2323    fn mapped_max_act_size_elems(
2324        &self,
2325        config: &str,
2326        params: &AutoDeviceMapParams,
2327        prompt_chunksize: usize,
2328    ) -> Result<usize> {
2329        let AutoDeviceMapParams::Text {
2330            max_seq_len: _,
2331            max_batch_size,
2332        } = params
2333        else {
2334            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2335        };
2336
2337        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2338
2339        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2340    }
2341    fn non_mapped_max_act_size_elems(
2342        &self,
2343        _config: &str,
2344        _params: &AutoDeviceMapParams,
2345    ) -> Result<usize> {
2346        Ok(0)
2347    }
2348
2349    fn non_mapped_size_in_bytes(
2350        &self,
2351        config: &str,
2352        dtype: DType,
2353        weight_pack_factor: usize,
2354    ) -> Result<usize> {
2355        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2356        let elems = {
2357            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2358            let lm_head = if !cfg.tie_word_embeddings {
2359                cfg.hidden_size * cfg.vocab_size
2360            } else {
2361                0
2362            };
2363            let norm = cfg.hidden_size + cfg.hidden_size;
2364            embed_tokens + lm_head + norm
2365        };
2366        Ok(elems * dtype.size_in_bytes())
2367    }
2368
2369    fn layer_sizes_in_bytes(
2370        &self,
2371        config: &str,
2372        dtype: DType,
2373        weight_pack_factor: usize,
2374    ) -> Result<Vec<usize>> {
2375        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
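        // Starcoder2 uses LayerNorm (weight + bias, hence the doubled hidden_size)
        // and gates every linear bias on `use_bias`.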
2376        let per_layer_elems = {
2377            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2378            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2379
2380            let size_in = cfg.hidden_size;
2381            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2382            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2383            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2384            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2385            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2386            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2387
2388            let h_size = cfg.hidden_size;
2389            let i_size = cfg.intermediate_size;
2390            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2391            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2392
2393            input_layernorm
2394                + post_attention_layernorm
2395                + q_proj
2396                + k_proj
2397                + v_proj
2398                + o_proj
2399                + fc1
2400                + fc2
2401        };
2402        Ok(vec![
2403            per_layer_elems * dtype.size_in_bytes();
2404            cfg.num_hidden_layers
2405        ])
2406    }
2407
2408    fn num_layers(&self, config: &str) -> Result<usize> {
2409        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2410        Ok(cfg.num_hidden_layers)
2411    }
2412
2413    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2414        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2415
2416        let cfg = ModelConfigMetadata {
2417            max_seq_len: cfg.max_position_embeddings,
2418            num_layers: cfg.num_hidden_layers,
2419            hidden_size: cfg.hidden_size,
2420            num_kv_heads: cfg.num_key_value_heads,
2421            num_attn_heads: cfg.num_attention_heads,
2422            sliding_window: cfg.sliding_window,
2423            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2424            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2425        };
2426
2427        Ok(Box::new(cfg))
2428    }
2429}
2430
2431// ======================== Phi 3.5 MoE loader
2432
2433#[derive(Deserialize)]
2434struct Phi3_5MoEBasicConfig {
2435    vocab_size: usize,
2436    hidden_act: Activation,
2437    hidden_size: usize,
2438    intermediate_size: usize,
2439    num_hidden_layers: usize,
2440    num_attention_heads: usize,
2441    num_key_value_heads: usize,
2442    rms_norm_eps: f64,
2443    rope_theta: f64,
2444    rope_scaling: Option<PhiRopeScalingConfig>,
2445    max_position_embeddings: usize,
2446    original_max_position_embeddings: usize,
2447    sliding_window: Option<usize>,
2448    quantization_config: Option<QuantizedConfig>,
2449    lm_head_bias: bool,
2450    attention_bias: bool,
2451    num_local_experts: usize,
2452    router_jitter_noise: f64,
2453    #[serde(default = "word_emb_default")]
2454    tie_word_embeddings: bool,
2455}
2456
2457impl Phi3_5MoEBasicConfig {
2458    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi3_5_moe::Config> {
2459        let basic_config: Self = serde_json::from_str(slice)?;
2460        Ok(models::phi3_5_moe::Config {
2461            vocab_size: basic_config.vocab_size,
2462            hidden_size: basic_config.hidden_size,
2463            intermediate_size: basic_config.intermediate_size,
2464            num_hidden_layers: basic_config.num_hidden_layers,
2465            num_attention_heads: basic_config.num_attention_heads,
2466            num_key_value_heads: basic_config.num_key_value_heads,
2467            hidden_act: basic_config.hidden_act,
2468            max_position_embeddings: basic_config.max_position_embeddings,
2469            rope_theta: basic_config.rope_theta,
2470            rms_norm_eps: basic_config.rms_norm_eps,
2471            rope_scaling: basic_config.rope_scaling,
2472            original_max_position_embeddings: basic_config.original_max_position_embeddings,
2473            use_flash_attn,
2474            sliding_window: basic_config.sliding_window,
2475            quantization_config: basic_config.quantization_config,
2476            lm_head_bias: basic_config.lm_head_bias,
2477            attention_bias: basic_config.attention_bias,
2478            num_local_experts: basic_config.num_local_experts,
2479            router_jitter_noise: basic_config.router_jitter_noise,
2480            tie_word_embeddings: basic_config.tie_word_embeddings,
2481        })
2482    }
2483}
2484
2485/// [`NormalLoader`] for a Phi 3.5 MoE model.
2486///
2487/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2488pub struct Phi3_5MoELoader;
2489
2490impl NormalModelLoader for Phi3_5MoELoader {
2491    fn load(
2492        &self,
2493        config: &str,
2494        use_flash_attn: bool,
2495        vb: ShardedVarBuilder,
2496        normal_loading_metadata: NormalLoadingMetadata,
2497        attention_mechanism: AttentionImplementation,
2498    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2499        Ok(Box::new(models::phi3_5_moe::Model::new(
2500            &Phi3_5MoEBasicConfig::deserialize(config, use_flash_attn)?,
2501            vb,
2502            self.is_gptx(config)?,
2503            normal_loading_metadata,
2504            attention_mechanism,
2505        )?))
2506    }
2507    fn load_xlora(
2508        &self,
2509        config: &str,
2510        use_flash_attn: bool,
2511        vb: ShardedVarBuilder,
2512        lora_config: &[((String, String), LoraConfig)],
2513        xlora_config: Option<XLoraConfig>,
2514        xlora_ordering: Ordering,
2515        normal_loading_metadata: NormalLoadingMetadata,
2516        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2517    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2518        Ok(Box::new(xlora_models::XLoraPhi3::new(
2519            &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
2520            vb,
2521            lora_config,
2522            xlora_config,
2523            xlora_ordering,
2524            self.is_gptx(config)?,
2525            normal_loading_metadata,
2526            preload_adapters,
2527        )?))
2528    }
2529    fn is_gptx(&self, _: &str) -> Result<bool> {
2530        Ok(true)
2531    }
2532    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2533        Ok(Box::new(Phi3_5MoEBasicConfig::deserialize(
2534            config,
2535            use_flash_attn,
2536        )?))
2537    }
2538}
2539
2540impl IsqModelLoader for Phi3_5MoELoader {
2541    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2542        Ok(vec![
2543            Regex::new(r"lm_head\.(weight|bias)$")?,
2544            // Attention
2545            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2546            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2547            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2548            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2549            // MLP
2550            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2551            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2552            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2553        ])
2554    }
2555
2556    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2557        Ok(vec![
2558            Regex::new(r"lm_head\.(weight|bias)$")?,
2559            // MLP
2560            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2561            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2562            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2563        ])
2564    }
2565}
2566
2567impl DeviceMappedModelLoader for Phi3_5MoELoader {
2568    fn mapped_max_act_size_elems(
2569        &self,
2570        config: &str,
2571        params: &AutoDeviceMapParams,
2572        prompt_chunksize: usize,
2573    ) -> Result<usize> {
2574        let AutoDeviceMapParams::Text {
2575            max_seq_len: _,
2576            max_batch_size,
2577        } = params
2578        else {
2579            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2580        };
2581
2582        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2583
2584        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2585    }
2586    fn non_mapped_max_act_size_elems(
2587        &self,
2588        _config: &str,
2589        _params: &AutoDeviceMapParams,
2590    ) -> Result<usize> {
2591        Ok(0)
2592    }
2593
2594    fn non_mapped_size_in_bytes(
2595        &self,
2596        config: &str,
2597        dtype: DType,
2598        weight_pack_factor: usize,
2599    ) -> Result<usize> {
2600        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2601        let elems = {
2602            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2603            let lm_head = if !cfg.tie_word_embeddings {
2604                cfg.hidden_size * cfg.vocab_size
2605            } else {
2606                0
2607            };
2608            let norm = cfg.hidden_size;
2609            embed_tokens + lm_head + norm
2610        };
2611        Ok(elems * dtype.size_in_bytes())
2612    }
2613
2614    fn layer_sizes_in_bytes(
2615        &self,
2616        config: &str,
2617        dtype: DType,
2618        weight_pack_factor: usize,
2619    ) -> Result<Vec<usize>> {
2620        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2621        let per_layer_elems = {
2622            let input_layernorm = cfg.hidden_size;
2623            let post_attention_layernorm = cfg.hidden_size;
2624
2625            let size_in = cfg.hidden_size;
2626            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2627            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2628            let q_proj =
2629                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2630            let k_proj =
2631                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2632            let v_proj =
2633                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2634            let o_proj =
2635                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2636
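            // MoE block: the router gate is hidden_size x num_local_experts, plus
            // num_local_experts copies of each expert projection (w1/w2/w3).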
2637            let moe_block = {
2638                let gate = cfg.hidden_size * cfg.num_local_experts;
2639                // Assume the expert weights are quantized, so apply the weight pack factor
2640                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2641                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2642                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2643                gate + cfg.num_local_experts * w1
2644                    + cfg.num_local_experts * w2
2645                    + cfg.num_local_experts * w3
2646            };
2647
2648            input_layernorm
2649                + post_attention_layernorm
2650                + q_proj
2651                + k_proj
2652                + v_proj
2653                + o_proj
2654                + moe_block
2655        };
2656        Ok(vec![
2657            per_layer_elems * dtype.size_in_bytes();
2658            cfg.num_hidden_layers
2659        ])
2660    }
2661
2662    fn num_layers(&self, config: &str) -> Result<usize> {
2663        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2664        Ok(cfg.num_hidden_layers)
2665    }
2666
2667    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2668        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2669
2670        let cfg = ModelConfigMetadata {
2671            max_seq_len: cfg.max_position_embeddings,
2672            num_layers: cfg.num_hidden_layers,
2673            hidden_size: cfg.hidden_size,
2674            num_kv_heads: cfg.num_key_value_heads,
2675            num_attn_heads: cfg.num_attention_heads,
2676            sliding_window: cfg.sliding_window,
2677            k_head_dim: cfg.head_dim(),
2678            v_head_dim: cfg.head_dim(),
2679        };
2680
2681        Ok(Box::new(cfg))
2682    }
2683}
2684
2685/// [`NormalLoader`] for a DeepSeekV2 model.
2686///
2687/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
2688pub struct DeepSeekV2Loader;
2689
2690impl NormalModelLoader for DeepSeekV2Loader {
2691    fn load(
2692        &self,
2693        config: &str,
2694        use_flash_attn: bool,
2695        vb: ShardedVarBuilder,
2696        normal_loading_metadata: NormalLoadingMetadata,
2697        attention_mechanism: AttentionImplementation,
2698    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2699        let mut cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2700        cfg.use_flash_attn = use_flash_attn;
2701        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2702            &cfg,
2703            vb,
2704            self.is_gptx(config)?,
2705            normal_loading_metadata,
2706            attention_mechanism,
2707        )?))
2708    }
2709    fn load_xlora(
2710        &self,
2711        _config: &str,
2712        _use_flash_attn: bool,
2713        _vb: ShardedVarBuilder,
2714        _lora_config: &[((String, String), LoraConfig)],
2715        _xlora_config: Option<XLoraConfig>,
2716        _xlora_ordering: Ordering,
2717        _normal_loading_metadata: NormalLoadingMetadata,
2718        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2719    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2720        todo!()
2721    }
2722    fn is_gptx(&self, _: &str) -> Result<bool> {
2723        Ok(true)
2724    }
2725    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2726        let mut config: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2727        config.use_flash_attn = use_flash_attn;
2728        Ok(Box::new(config))
2729    }
2730}
2731
2732impl IsqModelLoader for DeepSeekV2Loader {
2733    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2734        let mut data = vec![
2735            Regex::new(r"lm_head\.(weight|bias)$")?,
2736            // Attention
2737            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2738            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2739            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2740        ];
2741        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2742        if cfg.q_lora_rank.is_some() {
2743            data.extend(vec![
2744                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2745                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2746            ]);
2747        } else {
2748            data.push(Regex::new(
2749                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2750            )?);
2751        }
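        // Layers before `first_k_dense_replace`, and layers whose index is not a
        // multiple of `moe_layer_freq`, use a dense MLP; the rest use routed experts
        // plus optional shared experts.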
2752        for layer_idx in 0..cfg.num_hidden_layers {
2753            if cfg.n_routed_experts.is_some()
2754                && layer_idx >= cfg.first_k_dense_replace
2755                && layer_idx % cfg.moe_layer_freq == 0
2756            {
2757                for i in 0..cfg.n_routed_experts.unwrap() {
2758                    data.extend(vec![
2759                        Regex::new(&format!(
2760                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2761                        ))?,
2762                        Regex::new(&format!(
2763                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2764                        ))?,
2765                        Regex::new(&format!(
2766                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2767                        ))?,
2768                    ]);
2769                }
2770                if cfg.n_shared_experts.is_some() {
2771                    data.extend(vec![
2772                        Regex::new(&format!(
2773                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2774                        ))?,
2775                        Regex::new(&format!(
2776                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2777                        ))?,
2778                        Regex::new(&format!(
2779                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2780                        ))?,
2781                    ]);
2782                }
2783            } else {
2784                data.extend(vec![
2785                    Regex::new(&format!(
2786                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2787                    ))?,
2788                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2789                    Regex::new(&format!(
2790                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2791                    ))?,
2792                ]);
2793            };
2794        }
2795        Ok(data)
2796    }
2797
2798    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2799        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2800        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2801        for layer_idx in 0..cfg.num_hidden_layers {
2802            if cfg.n_routed_experts.is_some()
2803                && layer_idx >= cfg.first_k_dense_replace
2804                && layer_idx % cfg.moe_layer_freq == 0
2805            {
2806                for i in 0..cfg.n_routed_experts.unwrap() {
2807                    data.extend(vec![
2808                        Regex::new(&format!(
2809                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2810                        ))?,
2811                        Regex::new(&format!(
2812                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2813                        ))?,
2814                        Regex::new(&format!(
2815                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2816                        ))?,
2817                    ]);
2818                }
2819                if cfg.n_shared_experts.is_some() {
2820                    data.extend(vec![
2821                        Regex::new(&format!(
2822                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2823                        ))?,
2824                        Regex::new(&format!(
2825                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2826                        ))?,
2827                        Regex::new(&format!(
2828                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2829                        ))?,
2830                    ]);
2831                }
2832            } else {
2833                data.extend(vec![
2834                    Regex::new(&format!(
2835                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2836                    ))?,
2837                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2838                    Regex::new(&format!(
2839                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2840                    ))?,
2841                ]);
2842            };
2843        }
2844        Ok(data)
2845    }
2846}
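// A minimal sketch (hypothetical tensor name and helper, not part of the mistral.rs
// API) of how the ISQ regexes above can be applied: a tensor is quantized if any
// pattern matches its fully qualified name.
#[allow(dead_code)]
fn tensor_matches_isq(regexes: &[Regex], tensor_name: &str) -> bool {
    // e.g. "model.layers.3.self_attn.kv_b_proj.weight" would match the attention rule.
    regexes.iter().any(|re| re.is_match(tensor_name))
}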
2847
2848impl DeviceMappedModelLoader for DeepSeekV2Loader {
2849    fn mapped_max_act_size_elems(
2850        &self,
2851        config: &str,
2852        params: &AutoDeviceMapParams,
2853        prompt_chunksize: usize,
2854    ) -> Result<usize> {
2855        let AutoDeviceMapParams::Text {
2856            max_seq_len: _,
2857            max_batch_size,
2858        } = params
2859        else {
2860            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2861        };
2862
2863        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2864
2865        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2866    }
2867    fn non_mapped_max_act_size_elems(
2868        &self,
2869        _config: &str,
2870        _params: &AutoDeviceMapParams,
2871    ) -> Result<usize> {
2872        Ok(0)
2873    }
2874
2875    fn non_mapped_size_in_bytes(
2876        &self,
2877        config: &str,
2878        dtype: DType,
2879        weight_pack_factor: usize,
2880    ) -> Result<usize> {
2881        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2882        let elems = {
2883            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2884            let lm_head = if !cfg.tie_word_embeddings {
2885                cfg.hidden_size * cfg.vocab_size
2886            } else {
2887                0
2888            };
2889            let norm = cfg.hidden_size;
2890            embed_tokens + lm_head + norm
2891        };
2892        Ok(elems * dtype.size_in_bytes())
2893    }
2894
2895    fn layer_sizes_in_bytes(
2896        &self,
2897        config: &str,
2898        dtype: DType,
2899        weight_pack_factor: usize,
2900    ) -> Result<Vec<usize>> {
2901        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2902        let mut per_layer_elems = Vec::new();
2903
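        // Per-layer parameter counts: MLA attention (optionally LoRA-compressed q plus the
        // compressed kv_a / kv_b pair) and either a dense MLP or a MoE block with routed and
        // optional shared experts. Quantization-packed projections are divided by
        // `weight_pack_factor`.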
2904        for layer_idx in 0..cfg.num_hidden_layers {
2905            let input_layernorm = cfg.hidden_size;
2906            let post_attention_layernorm = cfg.hidden_size;
2907
2908            let q_proj = match cfg.q_lora_rank {
2909                Some(lora_rank) => {
2910                    let a = cfg.hidden_size * lora_rank;
2911                    let norm = lora_rank;
2912                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2913                    a + norm + b
2914                }
2915                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2916            };
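            // kv_a_proj_with_mqa compresses hidden_size to (kv_lora_rank + qk_rope_head_dim);
            // kv_b_proj then expands kv_lora_rank to
            // num_attention_heads * (q_head_dim() - qk_rope_head_dim + v_head_dim), which should
            // equal num_heads * (qk_nope_head_dim + v_head_dim) if q_head_dim() is the sum of the
            // nope and rope head dims.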
2917            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2918                / weight_pack_factor
2919                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2920            let kv_a_layernorm = cfg.kv_lora_rank;
2921            let kv_b_proj = cfg.kv_lora_rank
2922                * cfg.num_attention_heads
2923                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2924                / weight_pack_factor;
2925            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2926                / weight_pack_factor
2927                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2928
2929            let moe_block = {
2930                let mut sum = 0;
2931                if cfg.n_routed_experts.is_some()
2932                    && layer_idx >= cfg.first_k_dense_replace
2933                    && layer_idx % cfg.moe_layer_freq == 0
2934                {
2935                    let h_size = cfg.hidden_size;
2936                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2937                        * cfg.n_routed_experts.unwrap();
2938                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2939                        * cfg.n_routed_experts.unwrap();
2940                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2941                        * cfg.n_routed_experts.unwrap();
2942                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2943                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2944                            / weight_pack_factor;
2945                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2946                            / weight_pack_factor;
2947                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2948                            / weight_pack_factor;
2949                        gate_proj + up_proj + down_proj
2950                    } else {
2951                        0
2952                    };
2953                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2954                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2955                } else {
2956                    let h_size = cfg.hidden_size;
2957                    let i_size = cfg.intermediate_size;
2958                    let gate_proj = h_size * i_size / weight_pack_factor;
2959                    let up_proj = h_size * i_size / weight_pack_factor;
2960                    let down_proj = i_size * h_size / weight_pack_factor;
2961                    sum += gate_proj + up_proj + down_proj;
2962                }
2963                sum
2964            };
2965
2966            per_layer_elems.push(
2967                input_layernorm
2968                    + post_attention_layernorm
2969                    + q_proj
2970                    + kv_a_layernorm
2971                    + kv_a_proj_with_mqa
2972                    + kv_b_proj
2973                    + o_proj
2974                    + moe_block,
2975            );
2976        }
2977
2978        Ok(per_layer_elems
2979            .into_iter()
2980            .map(|x| x * dtype.size_in_bytes())
2981            .collect())
2982    }
2983
2984    fn num_layers(&self, config: &str) -> Result<usize> {
2985        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2986        Ok(cfg.num_hidden_layers)
2987    }
2988
2989    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2990        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2991
2992        let cfg = ModelConfigMetadata {
2993            max_seq_len: cfg.max_position_embeddings,
2994            num_layers: cfg.num_hidden_layers,
2995            hidden_size: cfg.hidden_size,
2996            num_kv_heads: cfg.num_attention_heads,
2997            num_attn_heads: cfg.num_attention_heads,
2998            sliding_window: None,
2999            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3000            v_head_dim: cfg.v_head_dim,
3001        };
3002
3003        Ok(Box::new(cfg))
3004    }
3005}
3006
3007/// [`NormalLoader`] for a DeepSeekV3 model.
3008///
3009/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
3010pub struct DeepSeekV3Loader;
3011
3012impl NormalModelLoader for DeepSeekV3Loader {
3013    fn load(
3014        &self,
3015        config: &str,
3016        use_flash_attn: bool,
3017        vb: ShardedVarBuilder,
3018        normal_loading_metadata: NormalLoadingMetadata,
3019        attention_mechanism: AttentionImplementation,
3020    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
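        // Parse the raw JSON config and inject the runtime flash-attention flag before
        // constructing the model.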
3021        let mut cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3022        cfg.use_flash_attn = use_flash_attn;
3023        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
3024            &cfg,
3025            vb,
3026            self.is_gptx(config)?,
3027            normal_loading_metadata,
3028            attention_mechanism,
3029        )?))
3030    }
3031    fn load_xlora(
3032        &self,
3033        _config: &str,
3034        _use_flash_attn: bool,
3035        _vb: ShardedVarBuilder,
3036        _lora_config: &[((String, String), LoraConfig)],
3037        _xlora_config: Option<XLoraConfig>,
3038        _xlora_ordering: Ordering,
3039        _normal_loading_metadata: NormalLoadingMetadata,
3040        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3041    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3042        todo!()
3043    }
3044    fn is_gptx(&self, _: &str) -> Result<bool> {
3045        Ok(true)
3046    }
3047    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
3048        let mut config: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3049        config.use_flash_attn = use_flash_attn;
3050        Ok(Box::new(config))
3051    }
3052}
3053
3054impl IsqModelLoader for DeepSeekV3Loader {
3055    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
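        // Always quantize lm_head and the attention projections; the q projection uses the
        // q_a/q_b pair when q_lora_rank is set, otherwise a single q_proj. Per-layer MLP and
        // expert patterns are appended below.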
3056        let mut data = vec![
3057            Regex::new(r"lm_head\.(weight|bias)$")?,
3058            // Attention
3059            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3060            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3061            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3062        ];
3063        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3064        if cfg.q_lora_rank.is_some() {
3065            data.extend(vec![
3066                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3067                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3068            ]);
3069        } else {
3070            data.push(Regex::new(
3071                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
3072            )?);
3073        }
3074        for layer_idx in 0..cfg.num_hidden_layers {
3075            if cfg.n_routed_experts.is_some()
3076                && layer_idx >= cfg.first_k_dense_replace
3077                && layer_idx % cfg.moe_layer_freq == 0
3078            {
3079                for i in 0..cfg.n_routed_experts.unwrap() {
3080                    data.extend(vec![
3081                        Regex::new(&format!(
3082                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3083                        ))?,
3084                        Regex::new(&format!(
3085                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3086                        ))?,
3087                        Regex::new(&format!(
3088                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3089                        ))?,
3090                    ]);
3091                }
3092                if cfg.n_shared_experts.is_some() {
3093                    data.extend(vec![
3094                        Regex::new(&format!(
3095                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3096                        ))?,
3097                        Regex::new(&format!(
3098                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3099                        ))?,
3100                        Regex::new(&format!(
3101                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3102                        ))?,
3103                    ]);
3104                }
3105            } else {
3106                data.extend(vec![
3107                    Regex::new(&format!(
3108                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3109                    ))?,
3110                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
3111                    Regex::new(&format!(
3112                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3113                    ))?,
3114                ]);
3115            };
3116        }
3117        Ok(data)
3118    }
3119
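    // MoQE variant: only lm_head and the MLP / expert projections are matched, leaving the
    // attention projections unquantized.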
3120    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3121        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3122        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3123        for layer_idx in 0..cfg.num_hidden_layers {
3124            if cfg.n_routed_experts.is_some()
3125                && layer_idx >= cfg.first_k_dense_replace
3126                && layer_idx % cfg.moe_layer_freq == 0
3127            {
3128                for i in 0..cfg.n_routed_experts.unwrap() {
3129                    data.extend(vec![
3130                        Regex::new(&format!(
3131                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3132                        ))?,
3133                        Regex::new(&format!(
3134                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3135                        ))?,
3136                        Regex::new(&format!(
3137                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3138                        ))?,
3139                    ]);
3140                }
3141                if cfg.n_shared_experts.is_some() {
3142                    data.extend(vec![
3143                        Regex::new(&format!(
3144                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3145                        ))?,
3146                        Regex::new(&format!(
3147                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3148                        ))?,
3149                        Regex::new(&format!(
3150                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3151                        ))?,
3152                    ]);
3153                }
3154            } else {
3155                data.extend(vec![
3156                    Regex::new(&format!(
3157                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3158                    ))?,
3159                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
3160                    Regex::new(&format!(
3161                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3162                    ))?,
3163                ]);
3164            };
3165        }
3166        Ok(data)
3167    }
3168}
3169
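// Mirrors the DeepSeek-V2 device-mapping estimates above; the V3 config exposes the same MLA
// and MoE sizing fields, so the per-layer arithmetic is identical.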
3170impl DeviceMappedModelLoader for DeepSeekV3Loader {
3171    fn mapped_max_act_size_elems(
3172        &self,
3173        config: &str,
3174        params: &AutoDeviceMapParams,
3175        prompt_chunksize: usize,
3176    ) -> Result<usize> {
3177        let AutoDeviceMapParams::Text {
3178            max_seq_len: _,
3179            max_batch_size,
3180        } = params
3181        else {
3182            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3183        };
3184
3185        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3186
3187        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3188    }
3189    fn non_mapped_max_act_size_elems(
3190        &self,
3191        _config: &str,
3192        _params: &AutoDeviceMapParams,
3193    ) -> Result<usize> {
3194        Ok(0)
3195    }
3196
3197    fn non_mapped_size_in_bytes(
3198        &self,
3199        config: &str,
3200        dtype: DType,
3201        weight_pack_factor: usize,
3202    ) -> Result<usize> {
3203        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3204        let elems = {
3205            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3206            let lm_head = if !cfg.tie_word_embeddings {
3207                cfg.hidden_size * cfg.vocab_size
3208            } else {
3209                0
3210            };
3211            let norm = cfg.hidden_size;
3212            embed_tokens + lm_head + norm
3213        };
3214        Ok(elems * dtype.size_in_bytes())
3215    }
3216
3217    fn layer_sizes_in_bytes(
3218        &self,
3219        config: &str,
3220        dtype: DType,
3221        weight_pack_factor: usize,
3222    ) -> Result<Vec<usize>> {
3223        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3224        let mut per_layer_elems = Vec::new();
3225
3226        for layer_idx in 0..cfg.num_hidden_layers {
3227            let input_layernorm = cfg.hidden_size;
3228            let post_attention_layernorm = cfg.hidden_size;
3229
3230            let q_proj = match cfg.q_lora_rank {
3231                Some(lora_rank) => {
3232                    let a = cfg.hidden_size * lora_rank;
3233                    let norm = lora_rank;
3234                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
3235                    a + norm + b
3236                }
3237                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
3238            };
3239            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
3240                / weight_pack_factor
3241                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
3242            let kv_a_layernorm = cfg.kv_lora_rank;
3243            let kv_b_proj = cfg.kv_lora_rank
3244                * cfg.num_attention_heads
3245                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3246                / weight_pack_factor;
3247            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
3248                / weight_pack_factor
3249                + bias_if!(cfg.attention_bias, cfg.hidden_size);
3250
3251            let moe_block = {
3252                let mut sum = 0;
3253                if cfg.n_routed_experts.is_some()
3254                    && layer_idx >= cfg.first_k_dense_replace
3255                    && layer_idx % cfg.moe_layer_freq == 0
3256                {
3257                    let h_size = cfg.hidden_size;
3258                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3259                        * cfg.n_routed_experts.unwrap();
3260                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3261                        * cfg.n_routed_experts.unwrap();
3262                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3263                        * cfg.n_routed_experts.unwrap();
3264                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3265                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3266                            / weight_pack_factor;
3267                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3268                            / weight_pack_factor;
3269                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3270                            / weight_pack_factor;
3271                        gate_proj + up_proj + down_proj
3272                    } else {
3273                        0
3274                    };
3275                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
3276                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3277                } else {
3278                    let h_size = cfg.hidden_size;
3279                    let i_size = cfg.intermediate_size;
3280                    let gate_proj = h_size * i_size / weight_pack_factor;
3281                    let up_proj = h_size * i_size / weight_pack_factor;
3282                    let down_proj = i_size * h_size / weight_pack_factor;
3283                    sum += gate_proj + up_proj + down_proj;
3284                }
3285                sum
3286            };
3287
3288            per_layer_elems.push(
3289                input_layernorm
3290                    + post_attention_layernorm
3291                    + q_proj
3292                    + kv_a_layernorm
3293                    + kv_a_proj_with_mqa
3294                    + kv_b_proj
3295                    + o_proj
3296                    + moe_block,
3297            );
3298        }
3299
3300        Ok(per_layer_elems
3301            .into_iter()
3302            .map(|x| x * dtype.size_in_bytes())
3303            .collect())
3304    }
3305
3306    fn num_layers(&self, config: &str) -> Result<usize> {
3307        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3308        Ok(cfg.num_hidden_layers)
3309    }
3310
3311    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3312        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3313
3314        let cfg = ModelConfigMetadata {
3315            max_seq_len: cfg.max_position_embeddings,
3316            num_layers: cfg.num_hidden_layers,
3317            hidden_size: cfg.hidden_size,
3318            num_kv_heads: cfg.num_attention_heads,
3319            num_attn_heads: cfg.num_attention_heads,
3320            sliding_window: None,
3321            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3322            v_head_dim: cfg.v_head_dim,
3323        };
3324
3325        Ok(Box::new(cfg))
3326    }
3327}