use std::{
    fmt::{self, Debug, Display},
    path::PathBuf,
    str::FromStr,
    sync::Arc,
};

use crate::{
    attention::ATTENTION_CHUNK_SIZE,
    embedding_models::{
        embedding_gemma::{EmbeddingGemma, EmbeddingGemmaConfig},
        qwen3_embedding::{Config as Qwen3EmbeddingConfig, Model as Qwen3EmbeddingModel},
    },
    matformer::MatformerSliceConfig,
    pipeline::{loaders::auto_device_map::NonMappedSubModel, NormalLoadingMetadata},
};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{isq::IsqModelLoader, text_models_inputs_processor::FlashParams, IsqModel},
    utils::varbuilder_utils::DeviceForLoadTensor,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use mistralrs_quant::log::once_log_info;

use mistralrs_quant::ShardedVarBuilder;
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::{de::Visitor, Deserialize, Deserializer, Serialize};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

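/// A loaded embedding model: maps tokenized `input_ids` to an embedding tensor.
/// Implementors are also ISQ-capable ([`IsqModel`]) and AnyMoE-compatible.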
pub trait EmbeddingModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
}

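/// Loader for a specific embedding architecture: parses the raw JSON `config`, constructs
/// the model from a [`ShardedVarBuilder`], and reports per-architecture metadata
/// (GPT-X layout, causal vs. bidirectional attention, config repr, per-tensor load device).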
pub trait EmbeddingModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn has_causal_attention(&self, config: &str) -> Result<bool>;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

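/// The architecture to load the embedding model as. Corresponds to the string names
/// `embeddinggemma` and `qwen3embedding`, and can also be derived from the Hugging Face
/// `architectures` class name via [`EmbeddingLoaderType::from_causal_lm_name`].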
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum EmbeddingLoaderType {
    #[serde(rename = "embeddinggemma")]
    EmbeddingGemma,
    #[serde(rename = "qwen3embedding")]
    Qwen3Embedding,
}

impl EmbeddingLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Gemma3TextModel" => Ok(Self::EmbeddingGemma),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3Embedding),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for EmbeddingLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "embeddinggemma" => Ok(Self::EmbeddingGemma),
            "qwen3embedding" => Ok(Self::Qwen3Embedding),
            a => Err(format!(
                "Unknown architecture `{a}`. Possible architectures: `embeddinggemma`, `qwen3embedding`."
            )),
        }
    }
}

impl Display for EmbeddingLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmbeddingGemma => write!(f, "embeddinggemma"),
            Self::Qwen3Embedding => write!(f, "qwen3embedding"),
        }
    }
}

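/// One module of a sentence-transformers pipeline, together with the paths of any
/// config/weight files it needs. [`EmbeddingModulePaths::serialize_modules`] renders a list
/// of these in the `modules.json` format (entries with `idx`, `name`, `path` and `type`).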
#[derive(Clone, Debug, Deserialize)]
pub enum EmbeddingModulePaths {
    Transformer {
        path: String,
    },
    Pooling {
        path: String,
        config: PathBuf,
    },
    Dense {
        path: String,
        config: PathBuf,
        model: PathBuf,
    },
    Normalize {
        path: String,
    },
}

impl EmbeddingModulePaths {
    pub fn serialize_modules(modules: &[EmbeddingModulePaths]) -> String {
        #[derive(Serialize)]
        struct OutputModule {
            idx: usize,
            name: String,
            path: String,
            #[serde(rename = "type")]
            ty: String,
        }

        let mapped: Vec<OutputModule> = modules
            .iter()
            .enumerate()
            .map(|(i, m)| {
                let (path, ty) = match m {
                    EmbeddingModulePaths::Transformer { path } => (
                        path.clone(),
                        "sentence_transformers.models.Transformer".to_string(),
                    ),
                    EmbeddingModulePaths::Pooling { path, .. } => (
                        path.clone(),
                        "sentence_transformers.models.Pooling".to_string(),
                    ),
                    EmbeddingModulePaths::Dense { path, .. } => (
                        path.clone(),
                        "sentence_transformers.models.Dense".to_string(),
                    ),
                    EmbeddingModulePaths::Normalize { path } => (
                        path.clone(),
                        "sentence_transformers.models.Normalize".to_string(),
                    ),
                };

                OutputModule {
                    idx: i,
                    name: i.to_string(),
                    path,
                    ty,
                }
            })
            .collect();

        serde_json::to_string_pretty(&mapped).unwrap()
    }
}

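/// A single entry read back from a sentence-transformers `modules.json` file; only the
/// `path` and `type` fields are used here. Illustrative entry (shape assumed to mirror the
/// serializer above): `{"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"}`.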
#[derive(Debug, Deserialize)]
pub struct EmbeddingModule {
    pub path: String,
    #[serde(rename = "type", deserialize_with = "deserialize_module_type")]
    pub ty: EmbeddingModuleType,
}

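/// The kind of sentence-transformers module. Deserialized from the dotted `type` string by
/// [`deserialize_module_type`], which matches only the final path segment, case-insensitively.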
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingModuleType {
    Transformer,
    Pooling,
    Dense,
    Normalize,
}

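/// Serde helper mapping strings such as `"sentence_transformers.models.Pooling"` (or just
/// `"Pooling"`) onto [`EmbeddingModuleType`].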
fn deserialize_module_type<'de, D>(deserializer: D) -> Result<EmbeddingModuleType, D::Error>
where
    D: Deserializer<'de>,
{
    struct ModuleTypeVisitor;

    impl<'de> Visitor<'de> for ModuleTypeVisitor {
        type Value = EmbeddingModuleType;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("a sentence-transformers module type string")
        }

        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            let last = v.rsplit('.').next().unwrap_or(v).to_ascii_lowercase();
            match last.as_str() {
                "transformer" => Ok(EmbeddingModuleType::Transformer),
                "pooling" => Ok(EmbeddingModuleType::Pooling),
                "dense" => Ok(EmbeddingModuleType::Dense),
                "normalize" => Ok(EmbeddingModuleType::Normalize),
                _ => Err(E::invalid_value(
                    serde::de::Unexpected::Str(v),
                    &"Transformer/Pooling/Dense/Normalize",
                )),
            }
        }
    }

    deserializer.deserialize_str(ModuleTypeVisitor)
}

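/// Counts `$size` extra elements for an optional bias tensor when `$cond` is true, and 0
/// otherwise; used in the per-layer size estimates below.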
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

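/// Loader that automatically selects the concrete [`EmbeddingModelLoader`] from the single
/// `architectures` entry of the model's `config.json`, then delegates every call to it.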
pub struct AutoEmbeddingLoader;

#[derive(Deserialize)]
struct AutoEmbeddingLoaderConfig {
    architectures: Vec<String>,
}

impl AutoEmbeddingLoader {
    fn get_loader(config: &str) -> Result<Box<dyn EmbeddingModelLoader>> {
        let auto_cfg: AutoEmbeddingLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected to have one name for `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = EmbeddingLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            EmbeddingLoaderType::EmbeddingGemma => Ok(Box::new(EmbeddingGemmaLoader)),
            EmbeddingLoaderType::Qwen3Embedding => Ok(Box::new(Qwen3EmbeddingLoader)),
        }
    }
}

impl EmbeddingModelLoader for AutoEmbeddingLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }
    fn has_causal_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.has_causal_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoEmbeddingLoader {
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
    }
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for AutoEmbeddingLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

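/// [`EmbeddingModelLoader`] for EmbeddingGemma (the Gemma3-based text embedding model);
/// reports non-causal (bidirectional) attention.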
pub struct EmbeddingGemmaLoader;

impl EmbeddingModelLoader for EmbeddingGemmaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        Ok(Box::new(EmbeddingGemma::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn has_causal_attention(&self, _: &str) -> Result<bool> {
        Ok(false)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for EmbeddingGemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for EmbeddingGemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
}

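/// [`EmbeddingModelLoader`] for Qwen3-based embedding models (`Qwen3ForCausalLM`);
/// keeps causal attention.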
pub struct Qwen3EmbeddingLoader;

impl EmbeddingModelLoader for Qwen3EmbeddingLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;

        Ok(Box::new(Qwen3EmbeddingModel::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn has_causal_attention(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Qwen3EmbeddingLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen3EmbeddingLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            let q_norm = cfg.head_dim();
            let k_norm = cfg.head_dim();

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
                + q_norm
                + k_norm
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}