use std::{
    fmt::{self, Debug, Display},
    path::PathBuf,
    str::FromStr,
    sync::Arc,
};

use crate::{
    attention::ATTENTION_CHUNK_SIZE,
    embedding_models::{
        embedding_gemma::{EmbeddingGemma, EmbeddingGemmaConfig},
        qwen3_embedding::{Config as Qwen3EmbeddingConfig, Model as Qwen3EmbeddingModel},
    },
    matformer::MatformerSliceConfig,
    pipeline::{loaders::auto_device_map::NonMappedSubModel, NormalLoadingMetadata},
};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{isq::IsqModelLoader, text_models_inputs_processor::FlashParams, IsqModel},
    utils::varbuilder_utils::DeviceForLoadTensor,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use mistralrs_quant::log::once_log_info;

use mistralrs_quant::ShardedVarBuilder;
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::{de::Visitor, Deserialize, Deserializer, Serialize};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

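/// Core interface implemented by every embedding model: a single forward pass mapping
/// token ids to an embedding tensor, plus access to the device the model lives on.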
pub trait EmbeddingModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
}

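/// Builds an [`EmbeddingModel`] from its JSON `config` string and a var builder, and
/// answers the config-level questions (attention style, config repr, tensor-to-device
/// routing) needed before and during loading.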
pub trait EmbeddingModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn has_causal_attention(&self, config: &str) -> Result<bool>;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
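            // Route each tensor to its layer's device: pull the layer index out of names
            // like `model.layers.12.self_attn.q_proj.weight`; tensors without a layer
            // index (embeddings, final norm, heads) stay on the base device.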
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

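/// Supported embedding model architectures. The serde names (`embeddinggemma`,
/// `qwen3embedding`) match the strings accepted by the [`FromStr`] impl below.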
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum EmbeddingLoaderType {
    #[serde(rename = "embeddinggemma")]
    EmbeddingGemma,
    #[serde(rename = "qwen3embedding")]
    Qwen3Embedding,
}

impl EmbeddingLoaderType {
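    /// Map a Hugging Face `architectures` class name (e.g. `Gemma3TextModel`) to the
    /// corresponding loader type, erroring on unsupported classes.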
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Gemma3TextModel" => Ok(Self::EmbeddingGemma),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3Embedding),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for EmbeddingLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "embeddinggemma" => Ok(Self::EmbeddingGemma),
            "qwen3embedding" => Ok(Self::Qwen3Embedding),
            a => Err(format!(
                "Unknown architecture `{a}`. Possible architectures: `embeddinggemma`, `qwen3embedding`."
            )),
        }
    }
}

impl Display for EmbeddingLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmbeddingGemma => write!(f, "embeddinggemma"),
            Self::Qwen3Embedding => write!(f, "qwen3embedding"),
        }
    }
}

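/// One entry of a sentence-transformers style module pipeline (transformer backbone,
/// pooling, optional dense projection, normalization), together with the on-disk paths to
/// each module's config and weights where applicable.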
#[derive(Clone, Debug, Deserialize)]
pub enum EmbeddingModulePaths {
    Transformer {
        path: String,
    },
    Pooling {
        path: String,
        config: PathBuf,
    },
    Dense {
        path: String,
        config: PathBuf,
        model: PathBuf,
    },
    Normalize {
        path: String,
    },
}

impl EmbeddingModulePaths {
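    /// Serialize the module list in the `modules.json` layout used by sentence-transformers:
    /// one `{ idx, name, path, type }` entry per module, in pipeline order.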
    pub fn serialize_modules(modules: &[EmbeddingModulePaths]) -> String {
        #[derive(Serialize)]
        struct OutputModule {
            idx: usize,
            name: String,
            path: String,
            #[serde(rename = "type")]
            ty: String,
        }

        let mapped: Vec<OutputModule> = modules
            .iter()
            .enumerate()
            .map(|(i, m)| {
                let (path, ty) = match m {
                    EmbeddingModulePaths::Transformer { path } => (
                        path.clone(),
                        "sentence_transformers.models.Transformer".to_string(),
                    ),
                    EmbeddingModulePaths::Pooling { path, .. } => (
                        path.clone(),
                        "sentence_transformers.models.Pooling".to_string(),
                    ),
                    EmbeddingModulePaths::Dense { path, .. } => (
                        path.clone(),
                        "sentence_transformers.models.Dense".to_string(),
                    ),
                    EmbeddingModulePaths::Normalize { path } => (
                        path.clone(),
                        "sentence_transformers.models.Normalize".to_string(),
                    ),
                };

                OutputModule {
                    idx: i,
                    name: i.to_string(),
                    path,
                    ty,
                }
            })
            .collect();

        serde_json::to_string_pretty(&mapped).unwrap()
    }
}

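/// A single entry deserialized from a sentence-transformers `modules.json` file: the
/// module's path and its type, reduced to an [`EmbeddingModuleType`].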
#[derive(Debug, Deserialize)]
pub struct EmbeddingModule {
    pub path: String,
    #[serde(rename = "type", deserialize_with = "deserialize_module_type")]
    pub ty: EmbeddingModuleType,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingModuleType {
    Transformer,
    Pooling,
    Dense,
    Normalize,
}

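/// Deserialize a module `type` field, accepting either the bare module name or a fully
/// qualified Python path (e.g. `sentence_transformers.models.Pooling`); only the final
/// path segment is matched, case-insensitively.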
fn deserialize_module_type<'de, D>(deserializer: D) -> Result<EmbeddingModuleType, D::Error>
where
    D: Deserializer<'de>,
{
    struct ModuleTypeVisitor;

    impl<'de> Visitor<'de> for ModuleTypeVisitor {
        type Value = EmbeddingModuleType;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("a sentence-transformers module type string")
        }

        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            let last = v.rsplit('.').next().unwrap_or(v).to_ascii_lowercase();
            match last.as_str() {
                "transformer" => Ok(EmbeddingModuleType::Transformer),
                "pooling" => Ok(EmbeddingModuleType::Pooling),
                "dense" => Ok(EmbeddingModuleType::Dense),
                "normalize" => Ok(EmbeddingModuleType::Normalize),
                _ => Err(E::invalid_value(
                    serde::de::Unexpected::Str(v),
                    &"Transformer/Pooling/Dense/Normalize",
                )),
            }
        }
    }

    deserializer.deserialize_str(ModuleTypeVisitor)
}

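/// Expands to `$size` when `$cond` holds and to `0` otherwise; used below to include
/// optional bias parameters when counting per-layer sizes.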
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

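/// Loader that picks the concrete embedding loader automatically from the single entry in
/// the config's `architectures` field and delegates every trait method to it.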
pub struct AutoEmbeddingLoader;

#[derive(Deserialize)]
struct AutoEmbeddingLoaderConfig {
    architectures: Vec<String>,
}

impl AutoEmbeddingLoader {
    fn get_loader(config: &str) -> Result<Box<dyn EmbeddingModelLoader>> {
        let auto_cfg: AutoEmbeddingLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected to have one name for `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = EmbeddingLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            EmbeddingLoaderType::EmbeddingGemma => Ok(Box::new(EmbeddingGemmaLoader)),
            EmbeddingLoaderType::Qwen3Embedding => Ok(Box::new(Qwen3EmbeddingLoader)),
        }
    }
}

impl EmbeddingModelLoader for AutoEmbeddingLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }
    fn has_causal_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.has_causal_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoEmbeddingLoader {
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
    }
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for AutoEmbeddingLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

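/// Loader for the [`EmbeddingGemma`] model, selected for the `Gemma3TextModel`
/// architecture (`embeddinggemma` loader type).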
pub struct EmbeddingGemmaLoader;

impl EmbeddingModelLoader for EmbeddingGemmaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        Ok(Box::new(EmbeddingGemma::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn has_causal_attention(&self, _: &str) -> Result<bool> {
        Ok(false)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for EmbeddingGemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for EmbeddingGemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

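        // Peak mapped activation: the attention-score matrix, batch * heads * seq_len^2,
        // with seq_len capped at the attention chunk size.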
        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        let per_layer_elems = {
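            // Parameters of one decoder layer: the two per-layer norms, the attention
            // projections (with optional biases), and the gated MLP.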
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
}

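/// Loader for the Qwen3 embedding model, selected for the `Qwen3ForCausalLM` architecture
/// (`qwen3embedding` loader type).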
pub struct Qwen3EmbeddingLoader;

impl EmbeddingModelLoader for Qwen3EmbeddingLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;

        Ok(Box::new(Qwen3EmbeddingModel::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn has_causal_attention(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Qwen3EmbeddingLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen3EmbeddingLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
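            // Count the lm_head separately unless the embeddings are tied and no weight
            // packing is in effect.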
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            let q_norm = cfg.head_dim();
            let k_norm = cfg.head_dim();

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
                + q_norm
                + k_norm
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}