1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 str::FromStr,
5 sync::Arc,
6};
7
8use crate::{
9 amoe::AnyMoeBaseModelMixin,
10 device_map::DeviceMapper,
11 layers::{Activation, Llama3RopeConfig, PhiRopeScalingConfig},
12 lora::{LoraConfig, Ordering},
13 paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
14 pipeline::{
15 isq::IsqModelLoader,
16 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
17 EitherCache, IsqModel,
18 },
19 serde_default_fn,
20 utils::{log::once_log_info, varbuilder_utils::DeviceForLoadTensor},
21 xlora_models::NonGranularState,
22};
23use anyhow::Result;
24use candle_core::{DType, Device, Tensor};
25
26use indicatif::MultiProgress;
27use mistralrs_quant::{QuantizedConfig, ShardedVarBuilder};
28#[cfg(feature = "pyo3_macros")]
29use pyo3::pyclass;
30
31use regex::Regex;
32use serde::Deserialize;
33
34use crate::{
35 models,
36 xlora_models::{self, XLoraConfig},
37};
38
39use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
40
41pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
42 #[allow(clippy::too_many_arguments)]
43 fn forward(
44 &self,
45 input_ids: &Tensor,
46 seqlen_offsets: &[usize],
47 context_lens: Vec<(usize, usize)>,
48 position_ids: Vec<usize>,
49 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
50 flash_params: &FlashParams,
51 ) -> candle_core::Result<Tensor>;
52 #[allow(clippy::too_many_arguments)]
53 fn xlora_forward(
54 &self,
55 input_ids: &Tensor,
56 input_ids_full: &Tensor,
57 seqlen_offsets: &[usize],
58 seqlen_offsets_full: &[usize],
59 no_kv_cache: bool,
60 non_granular_state: &Option<NonGranularState>,
61 context_lens: Vec<(usize, usize)>,
62 position_ids: Vec<usize>,
63 flash_params: &FlashParams,
64 flash_params_full: &FlashParams,
65 ) -> candle_core::Result<Tensor>;
66 fn is_xlora(&self) -> bool;
67 fn device(&self) -> &Device;
68 fn cache(&self) -> &EitherCache;
69 fn cache_mut(&mut self) -> &mut EitherCache;
70 fn max_seq_len(&self) -> usize;
71 fn config(&self) -> &ModelConfigMetadata;
72}
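// Illustrative call shape for `NormalModel::forward` (a sketch added for exposition,
// not code taken from this crate). A pipeline decoding a batch of two sequences that
// already hold `n0` and `n1` tokens in the KV cache might call:
//
//     let logits = model.forward(
//         &input_ids,           // one new token per sequence
//         &[n0, n1],            // per-sequence KV-cache offsets
//         vec![(0, 1), (0, 1)], // per-sequence context lengths (values only illustrative)
//         vec![n0, n1],         // per-sequence position ids (values only illustrative)
//         None,                 // no paged-attention metadata
//         &flash_params,
//     )?;
//
// The exact contents of `context_lens` and `position_ids` are produced by the text
// inputs processor; the values above only show the argument shapes.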
73
pub struct NormalLoadingMetadata {
    // Device mapping metadata which can be used to construct a concrete device mapper
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Flag to check if loading in ISQ
    pub loading_isq: bool,
    // Device mapping target device (the one that is not the cpu)
    pub real_device: Device,
    // MultiProgress support for parallelized loading
    pub multi_progress: Arc<MultiProgress>,
}
85
86pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
87 fn load(
88 &self,
89 config: &str,
90 use_flash_attn: bool,
91 vb: ShardedVarBuilder,
92 normal_loading_metadata: NormalLoadingMetadata,
93 attention_mechanism: AttentionImplementation,
94 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
95 #[allow(clippy::too_many_arguments)]
96 fn load_xlora(
97 &self,
98 config: &str,
99 use_flash_attn: bool,
100 vb: ShardedVarBuilder,
101 lora_config: &[((String, String), LoraConfig)],
102 xlora_config: Option<XLoraConfig>,
103 xlora_ordering: Ordering,
104 normal_loading_metadata: NormalLoadingMetadata,
105 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
106 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
107 fn is_gptx(&self, config: &str) -> Result<bool>;
108 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
109 Ok(true)
110 }
111 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>>;
112 fn get_device_for_tensor(
113 &self,
114 config: &str,
115 _mapper: &dyn DeviceMapper,
116 loading_isq: bool,
117 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
118 if loading_isq {
119 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
120 } else {
121 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
122 let num_layers = self.model_config(config)?.num_layers();
123 let closure = move |name: String| {
124 if let Some(captures) = re.captures(&name) {
125 captures
126 .get(1)
127 .and_then(|m| m.as_str().parse::<usize>().ok())
128 .map(|l| l.min(num_layers))
129 .map(DeviceForLoadTensor::Idx)
130 .unwrap_or(DeviceForLoadTensor::Base)
131 } else {
132 DeviceForLoadTensor::Base
133 }
134 };
135
136 Ok(Arc::new(closure))
137 }
138 }
139}
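// Example of the default tensor-name -> device policy above (illustrative): a weight
// named "model.layers.10.self_attn.q_proj.weight" captures `10` and resolves to
// `DeviceForLoadTensor::Idx(10)`, while "model.embed_tokens.weight" contains no
// `.layers.N.` segment and falls back to `DeviceForLoadTensor::Base`. When ISQ
// loading is active, every tensor resolves to `Base`.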
140
141#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
142#[derive(Clone, Debug, Deserialize, PartialEq)]
143pub enum NormalLoaderType {
145 #[serde(rename = "mistral")]
146 Mistral,
147 #[serde(rename = "gemma")]
148 Gemma,
149 #[serde(rename = "mixtral")]
150 Mixtral,
151 #[serde(rename = "llama")]
152 Llama,
153 #[serde(rename = "phi2")]
154 Phi2,
155 #[serde(rename = "phi3")]
156 Phi3,
157 #[serde(rename = "qwen2")]
158 Qwen2,
159 #[serde(rename = "gemma2")]
160 Gemma2,
161 #[serde(rename = "starcoder2")]
162 Starcoder2,
163 #[serde(rename = "phi3.5moe")]
164 Phi3_5MoE,
165 #[serde(rename = "deepseekv2")]
166 DeepSeekV2,
167 #[serde(rename = "deepseekv3")]
168 DeepSeekV3,
169 #[serde(rename = "qwen3")]
170 Qwen3,
171 #[serde(rename = "qwen3moe")]
172 Qwen3Moe,
173}
174
175impl NormalLoaderType {
178 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
179 match name {
180 "MistralForCausalLM" => Ok(Self::Mistral),
181 "MixtralForCausalLM" => Ok(Self::Mixtral),
182 "GemmaForCausalLM" => Ok(Self::Gemma),
183 "Gemma2ForCausalLM" => Ok(Self::Gemma2),
184 "PhiForCausalLM" => Ok(Self::Phi2),
185 "Phi3ForCausalLM" => Ok(Self::Phi3),
186 "LlamaForCausalLM" => Ok(Self::Llama),
187 "Qwen2ForCausalLM" => Ok(Self::Qwen2),
188 "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
189 "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
190 "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
191 "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
192 "Qwen3ForCausalLM" => Ok(Self::Qwen3),
193 "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
194 other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `-CausalLM` model class `{other}`. Please raise an issue."
196 ),
197 }
198 }
199}
200
201impl FromStr for NormalLoaderType {
202 type Err = String;
203 fn from_str(s: &str) -> Result<Self, Self::Err> {
204 match s {
205 "mistral" => Ok(Self::Mistral),
206 "gemma" => Ok(Self::Gemma),
207 "mixtral" => Ok(Self::Mixtral),
208 "llama" => Ok(Self::Llama),
209 "phi2" => Ok(Self::Phi2),
210 "phi3" => Ok(Self::Phi3),
211 "qwen2" => Ok(Self::Qwen2),
212 "gemma2" => Ok(Self::Gemma2),
213 "starcoder2" => Ok(Self::Starcoder2),
214 "phi3.5moe" => Ok(Self::Phi3_5MoE),
215 "deepseekv2" => Ok(Self::DeepSeekV2),
216 "deepseekv3" => Ok(Self::DeepSeekV3),
            "qwen3" => Ok(Self::Qwen3),
218 "qwen3moe" => Ok(Self::Qwen3Moe),
219 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `qwen3moe`.")),
220 }
221 }
222}
223
224impl Display for NormalLoaderType {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 match self {
227 Self::Gemma => write!(f, "gemma"),
228 Self::Gemma2 => write!(f, "gemma2"),
229 Self::Llama => write!(f, "llama"),
230 Self::Mistral => write!(f, "mistral"),
231 Self::Mixtral => write!(f, "mixtral"),
232 Self::Phi2 => write!(f, "phi2"),
233 Self::Phi3 => write!(f, "phi3"),
234 Self::Phi3_5MoE => write!(f, "phi3.5moe"),
235 Self::Qwen2 => write!(f, "qwen2"),
236 Self::Starcoder2 => write!(f, "starcoder2"),
237 Self::DeepSeekV2 => write!(f, "deepseekv2"),
238 Self::DeepSeekV3 => write!(f, "deepseekv3"),
239 Self::Qwen3 => write!(f, "qwen3"),
240 Self::Qwen3Moe => write!(f, "qwen3moe"),
241 }
242 }
243}
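// Small sanity test added for illustration (not part of the original file): the
// architecture strings accepted by `FromStr` should round-trip through `Display`.
#[cfg(test)]
mod normal_loader_type_tests {
    use super::NormalLoaderType;
    use std::str::FromStr;

    #[test]
    fn from_str_and_display_round_trip() {
        for name in ["mistral", "llama", "phi3.5moe", "qwen3", "qwen3moe"] {
            let ty = NormalLoaderType::from_str(name).expect("known architecture");
            assert_eq!(ty.to_string(), name);
        }
    }
}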
244
245macro_rules! bias_if {
246 ($cond:expr, $size:expr) => {
247 if $cond {
248 $size
249 } else {
250 0
251 }
252 };
253}
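// For example, `bias_if!(cfg.attention_bias, size_q)` contributes `size_q` bias
// elements to a projection's size estimate when the config enables attention biases,
// and 0 otherwise. It is used in the per-layer size estimates below.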
254
255pub struct AutoLoader;
257
258#[derive(Deserialize)]
259struct AutoLoaderConfig {
260 architectures: Vec<String>,
261}
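// `AutoLoader` only reads the `architectures` field of a Hugging Face `config.json`
// to pick a concrete loader. A minimal illustrative example:
//
//     { "architectures": ["MistralForCausalLM"], ... }
//
// resolves to `NormalLoaderType::Mistral` via `from_causal_lm_name`, and all other
// trait methods are then delegated to `MistralLoader`.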
262
263impl AutoLoader {
264 fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
265 let auto_cfg: AutoLoaderConfig = serde_json::from_str(config)?;
266 if auto_cfg.architectures.len() != 1 {
267 anyhow::bail!("Expected to have one name for `architectures` config field.")
268 }
269
270 let name = &auto_cfg.architectures[0];
271
272 let tp = NormalLoaderType::from_causal_lm_name(name)?;
273
274 once_log_info(format!("Automatic loader type determined to be `{tp}`"));
275
276 match tp {
277 NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
278 NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
279 NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
280 NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
281 NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
282 NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
283 NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
284 NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
285 NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
286 NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
287 NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
288 NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
289 NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
290 NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
291 }
292 }
293}
294
295impl NormalModelLoader for AutoLoader {
296 fn load(
297 &self,
298 config: &str,
299 use_flash_attn: bool,
300 vb: ShardedVarBuilder,
301 normal_loading_metadata: NormalLoadingMetadata,
302 attention_mechanism: AttentionImplementation,
303 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
304 Self::get_loader(config)?.load(
305 config,
306 use_flash_attn,
307 vb,
308 normal_loading_metadata,
309 attention_mechanism,
310 )
311 }
312 fn load_xlora(
313 &self,
314 config: &str,
315 use_flash_attn: bool,
316 vb: ShardedVarBuilder,
317 lora_config: &[((String, String), LoraConfig)],
318 xlora_config: Option<XLoraConfig>,
319 xlora_ordering: Ordering,
320 normal_loading_metadata: NormalLoadingMetadata,
321 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
322 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
323 Self::get_loader(config)?.load_xlora(
324 config,
325 use_flash_attn,
326 vb,
327 lora_config,
328 xlora_config,
329 xlora_ordering,
330 normal_loading_metadata,
331 preload_adapters,
332 )
333 }
334 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
335 Self::get_loader(config)?.get_config_repr(config, use_flash_attn)
336 }
337 fn supports_paged_attention(&self, config: &str) -> Result<bool> {
338 Self::get_loader(config)?.supports_paged_attention(config)
339 }
340 fn is_gptx(&self, config: &str) -> Result<bool> {
341 Self::get_loader(config)?.is_gptx(config)
342 }
343}
344
345impl IsqModelLoader for AutoLoader {
346 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
347 Self::get_loader(config)?.isq_layer_regexes(config)
348 }
349}
350
351impl DeviceMappedModelLoader for AutoLoader {
352 fn non_mapped_size_in_bytes(
353 &self,
354 config: &str,
355 dtype: DType,
356 weight_pack_factor: usize,
357 ) -> Result<usize> {
358 Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
359 }
360 fn num_layers(&self, config: &str) -> Result<usize> {
361 Self::get_loader(config)?.num_layers(config)
362 }
363 fn layer_sizes_in_bytes(
364 &self,
365 config: &str,
366 dtype: DType,
367 weight_pack_factor: usize,
368 ) -> Result<Vec<usize>> {
369 Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
370 }
371 fn mapped_max_act_size_elems(
372 &self,
373 config: &str,
374 params: &super::AutoDeviceMapParams,
375 prompt_chunksize: usize,
376 ) -> Result<usize> {
377 Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
378 }
379 fn non_mapped_max_act_size_elems(
380 &self,
381 _config: &str,
382 _params: &AutoDeviceMapParams,
383 ) -> Result<usize> {
384 Ok(0)
385 }
386 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
387 Self::get_loader(config)?.model_config(config)
388 }
389}
390
391serde_default_fn!(bool, word_emb_default, false);
392
393#[derive(Deserialize, Debug)]
396struct MistralBasicConfig {
397 vocab_size: usize,
398 hidden_size: usize,
399 intermediate_size: usize,
400 num_hidden_layers: usize,
401 num_attention_heads: usize,
402 num_key_value_heads: usize,
403 hidden_act: Activation,
404 max_position_embeddings: usize,
405 rms_norm_eps: f64,
406 rope_theta: f64,
407 sliding_window: Option<usize>,
408 head_dim: Option<usize>,
409 quantization_config: Option<QuantizedConfig>,
410 #[serde(default = "word_emb_default")]
411 tie_word_embeddings: bool,
412}
413
414impl MistralBasicConfig {
415 fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::mistral::Config> {
416 let basic_config: Self = serde_json::from_str(slice)?;
417 Ok(models::mistral::Config {
418 vocab_size: basic_config.vocab_size,
419 hidden_size: basic_config.hidden_size,
420 intermediate_size: basic_config.intermediate_size,
421 num_hidden_layers: basic_config.num_hidden_layers,
422 num_attention_heads: basic_config.num_attention_heads,
423 num_key_value_heads: basic_config.num_key_value_heads,
424 hidden_act: basic_config.hidden_act,
425 max_position_embeddings: basic_config.max_position_embeddings,
426 rms_norm_eps: basic_config.rms_norm_eps,
427 rope_theta: basic_config.rope_theta,
428 sliding_window: basic_config.sliding_window,
429 use_flash_attn,
430 head_dim: basic_config.head_dim,
431 quantization_config: basic_config.quantization_config,
432 tie_word_embeddings: basic_config.tie_word_embeddings,
433 })
434 }
435}
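// Illustrative fragment of a `config.json` that deserializes into `MistralBasicConfig`
// (example values, not canonical):
//
//     {
//       "vocab_size": 32000,
//       "hidden_size": 4096,
//       "intermediate_size": 14336,
//       "num_hidden_layers": 32,
//       "num_attention_heads": 32,
//       "num_key_value_heads": 8,
//       "hidden_act": "silu",
//       "max_position_embeddings": 32768,
//       "rms_norm_eps": 1e-5,
//       "rope_theta": 1000000.0
//     }
//
// Optional fields (`sliding_window`, `head_dim`, `quantization_config`) may be absent,
// and `tie_word_embeddings` defaults to `false` via `word_emb_default`.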
436
437pub struct MistralLoader;
438
439impl NormalModelLoader for MistralLoader {
440 fn load(
441 &self,
442 config: &str,
443 use_flash_attn: bool,
444 vb: ShardedVarBuilder,
445 normal_loading_metadata: NormalLoadingMetadata,
446 attention_mechanism: AttentionImplementation,
447 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
448 Ok(Box::new(models::mistral::Model::new(
449 &MistralBasicConfig::deserialize(config, use_flash_attn)?,
450 vb,
451 self.is_gptx(config)?,
452 normal_loading_metadata,
453 attention_mechanism,
454 )?))
455 }
456 fn load_xlora(
457 &self,
458 config: &str,
459 use_flash_attn: bool,
460 vb: ShardedVarBuilder,
461 lora_config: &[((String, String), LoraConfig)],
462 xlora_config: Option<XLoraConfig>,
463 xlora_ordering: Ordering,
464 normal_loading_metadata: NormalLoadingMetadata,
465 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
466 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
467 Ok(Box::new(xlora_models::XLoraMistral::new(
468 &MistralBasicConfig::deserialize(config, use_flash_attn)?,
469 vb,
470 lora_config,
471 xlora_config,
472 xlora_ordering,
473 self.is_gptx(config)?,
474 normal_loading_metadata,
475 preload_adapters,
476 )?))
477 }
478 fn is_gptx(&self, _: &str) -> Result<bool> {
479 Ok(true)
480 }
481 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
482 Ok(Box::new(MistralBasicConfig::deserialize(
483 config,
484 use_flash_attn,
485 )?))
486 }
487}
488
489impl IsqModelLoader for MistralLoader {
490 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
491 Ok(vec![
492 Regex::new(r"lm_head\.(weight|bias)$")?,
493 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
495 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
496 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
497 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
498 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
500 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
501 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
502 ])
503 }
504}
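// Illustrative tensor names matched by the ISQ regexes above: "lm_head.weight",
// "model.layers.0.self_attn.q_proj.weight", "model.layers.31.mlp.down_proj.weight".
// The patterns are unanchored at the front, so fully qualified names match; norms and
// the token embedding are not matched and therefore stay unquantized.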
505
506impl DeviceMappedModelLoader for MistralLoader {
507 fn mapped_max_act_size_elems(
508 &self,
509 config: &str,
510 params: &AutoDeviceMapParams,
511 prompt_chunksize: usize,
512 ) -> Result<usize> {
513 let AutoDeviceMapParams::Text {
514 max_seq_len: _,
515 max_batch_size,
516 } = params
517 else {
518 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
519 };
520
521 let cfg = MistralBasicConfig::deserialize(config, false)?;
522
523 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
524 }
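    // The `mapped_max_act_size_elems` estimate above treats the attention score
    // matrices for one prompt chunk as the dominant mapped activation:
    // `max_batch_size * num_heads * chunk * chunk` elements. Decode-step activations
    // are much smaller, so prefill sets the budget.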
525 fn non_mapped_max_act_size_elems(
526 &self,
527 _config: &str,
528 _params: &AutoDeviceMapParams,
529 ) -> Result<usize> {
530 Ok(0)
531 }
532
533 fn non_mapped_size_in_bytes(
534 &self,
535 config: &str,
536 dtype: DType,
537 weight_pack_factor: usize,
538 ) -> Result<usize> {
539 let cfg = MistralBasicConfig::deserialize(config, false)?;
540 let elems = {
541 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
542 let lm_head = if !cfg.tie_word_embeddings {
543 cfg.hidden_size * cfg.vocab_size
544 } else {
545 0
546 };
547 let norm = cfg.hidden_size;
548 embed_tokens + lm_head + norm
549 };
550 Ok(elems * dtype.size_in_bytes())
551 }
552
553 fn layer_sizes_in_bytes(
554 &self,
555 config: &str,
556 dtype: DType,
557 weight_pack_factor: usize,
558 ) -> Result<Vec<usize>> {
559 let cfg = MistralBasicConfig::deserialize(config, false)?;
560 let per_layer_elems = {
561 let input_layernorm = cfg.hidden_size;
562 let post_attention_layernorm = cfg.hidden_size;
563
564 let size_in = cfg.hidden_size;
565 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
566 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
567 let q_proj = size_in * size_q / weight_pack_factor;
568 let k_proj = size_in * size_kv / weight_pack_factor;
569 let v_proj = size_in * size_kv / weight_pack_factor;
570 let o_proj = size_q * size_in / weight_pack_factor;
571
572 let h_size = cfg.hidden_size;
573 let i_size = cfg.intermediate_size;
574 let gate_proj = h_size * i_size / weight_pack_factor;
575 let up_proj = h_size * i_size / weight_pack_factor;
576 let down_proj = i_size * h_size / weight_pack_factor;
577
578 input_layernorm
579 + post_attention_layernorm
580 + q_proj
581 + k_proj
582 + v_proj
583 + o_proj
584 + gate_proj
585 + up_proj
586 + down_proj
587 };
588 Ok(vec![
589 per_layer_elems * dtype.size_in_bytes();
590 cfg.num_hidden_layers
591 ])
592 }
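    // Rough worked example for the per-layer estimate above (illustrative,
    // Mistral-7B-like dimensions: hidden 4096, intermediate 14336, 32 heads, 8 KV
    // heads): q/o are 4096*4096, k/v are 4096*1024, gate/up/down are 4096*14336 each,
    // plus two RMSNorm weights, for roughly 2.18e8 elements per layer. In F16 that is
    // about 0.44 GB per layer and about 14 GB over 32 layers before any weight
    // packing; embeddings and the LM head are counted in `non_mapped_size_in_bytes`
    // above.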
593
594 fn num_layers(&self, config: &str) -> Result<usize> {
595 let cfg = MistralBasicConfig::deserialize(config, false)?;
596 Ok(cfg.num_hidden_layers)
597 }
598
599 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
600 let cfg = MistralBasicConfig::deserialize(config, false)?;
601
602 let cfg = ModelConfigMetadata {
603 max_seq_len: cfg.max_position_embeddings,
604 num_layers: cfg.num_hidden_layers,
605 hidden_size: cfg.hidden_size,
606 num_kv_heads: cfg.num_key_value_heads,
607 num_attn_heads: cfg.num_attention_heads,
608 sliding_window: cfg.sliding_window,
609 k_head_dim: cfg.head_dim(),
610 v_head_dim: cfg.head_dim(),
611 };
612
613 Ok(Box::new(cfg))
614 }
615}
616
617fn default_max_position_embeddings() -> usize {
620 4096
621}
622
623#[derive(Deserialize)]
624struct GemmaBasicConfig {
625 attention_bias: bool,
626 head_dim: usize,
627 hidden_act: Option<Activation>,
629 hidden_activation: Option<Activation>,
630 hidden_size: usize,
631 intermediate_size: usize,
632 num_attention_heads: usize,
633 num_hidden_layers: usize,
634 num_key_value_heads: usize,
635 rms_norm_eps: f64,
636 rope_theta: f64,
637 vocab_size: usize,
638
639 #[serde(default = "default_max_position_embeddings")]
640 max_position_embeddings: usize,
641 quantization_config: Option<QuantizedConfig>,
642 #[serde(default = "word_emb_default")]
643 tie_word_embeddings: bool,
644}
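// Gemma configs on the Hub have shipped the activation under two different keys over
// time (`hidden_act` in older configs, `hidden_activation` in newer ones), so both are
// kept optional here and the model code resolves whichever one is present.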
645
646impl GemmaBasicConfig {
647 fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::gemma::Config> {
648 let basic_config: Self = serde_json::from_str(slice)?;
649 Ok(models::gemma::Config {
650 vocab_size: basic_config.vocab_size,
651 hidden_size: basic_config.hidden_size,
652 intermediate_size: basic_config.intermediate_size,
653 num_hidden_layers: basic_config.num_hidden_layers,
654 num_attention_heads: basic_config.num_attention_heads,
655 num_key_value_heads: basic_config.num_key_value_heads,
656 hidden_act: basic_config.hidden_act,
657 hidden_activation: basic_config.hidden_activation,
658 max_position_embeddings: basic_config.max_position_embeddings,
659 rms_norm_eps: basic_config.rms_norm_eps,
660 rope_theta: basic_config.rope_theta,
661 attention_bias: basic_config.attention_bias,
662 head_dim: basic_config.head_dim,
663 use_flash_attn,
664 quantization_config: basic_config.quantization_config,
665 tie_word_embeddings: basic_config.tie_word_embeddings,
666 })
667 }
668}
669
670pub struct GemmaLoader;
674
675impl NormalModelLoader for GemmaLoader {
676 fn load(
677 &self,
678 config: &str,
679 use_flash_attn: bool,
680 vb: ShardedVarBuilder,
681 normal_loading_metadata: NormalLoadingMetadata,
682 attention_mechanism: AttentionImplementation,
683 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
684 Ok(Box::new(models::gemma::Model::new(
685 &GemmaBasicConfig::deserialize(config, use_flash_attn)?,
686 vb,
687 self.is_gptx(config)?,
688 normal_loading_metadata,
689 attention_mechanism,
690 )?))
691 }
692 fn load_xlora(
693 &self,
694 config: &str,
695 use_flash_attn: bool,
696 vb: ShardedVarBuilder,
697 lora_config: &[((String, String), LoraConfig)],
698 xlora_config: Option<XLoraConfig>,
699 xlora_ordering: Ordering,
700 normal_loading_metadata: NormalLoadingMetadata,
701 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
702 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
703 Ok(Box::new(xlora_models::XLoraGemma::new(
704 &GemmaBasicConfig::deserialize(config, use_flash_attn)?,
705 vb,
706 lora_config,
707 xlora_config,
708 xlora_ordering,
709 self.is_gptx(config)?,
710 normal_loading_metadata,
711 preload_adapters,
712 )?))
713 }
714 fn is_gptx(&self, _: &str) -> Result<bool> {
715 Ok(true)
716 }
717 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
718 Ok(Box::new(GemmaBasicConfig::deserialize(
719 config,
720 use_flash_attn,
721 )?))
722 }
723}
724
725impl IsqModelLoader for GemmaLoader {
726 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
727 Ok(vec![
728 Regex::new(r"lm_head\.(weight|bias)$")?,
729 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
731 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
732 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
733 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
734 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
736 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
737 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
738 ])
739 }
740}
741
742impl DeviceMappedModelLoader for GemmaLoader {
743 fn mapped_max_act_size_elems(
744 &self,
745 config: &str,
746 params: &AutoDeviceMapParams,
747 prompt_chunksize: usize,
748 ) -> Result<usize> {
749 let AutoDeviceMapParams::Text {
750 max_seq_len: _,
751 max_batch_size,
752 } = params
753 else {
754 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
755 };
756
757 let cfg = GemmaBasicConfig::deserialize(config, false)?;
758
759 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
760 }
761 fn non_mapped_max_act_size_elems(
762 &self,
763 _config: &str,
764 _params: &AutoDeviceMapParams,
765 ) -> Result<usize> {
766 Ok(0)
767 }
768
769 fn non_mapped_size_in_bytes(
770 &self,
771 config: &str,
772 dtype: DType,
773 weight_pack_factor: usize,
774 ) -> Result<usize> {
775 let cfg = GemmaBasicConfig::deserialize(config, false)?;
776 let elems = {
777 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
778 let lm_head = if !cfg.tie_word_embeddings {
779 cfg.hidden_size * cfg.vocab_size
780 } else {
781 0
782 };
783 let norm = cfg.hidden_size;
784 embed_tokens + lm_head + norm
785 };
786 Ok(elems * dtype.size_in_bytes())
787 }
788
789 fn layer_sizes_in_bytes(
790 &self,
791 config: &str,
792 dtype: DType,
793 weight_pack_factor: usize,
794 ) -> Result<Vec<usize>> {
795 let cfg = GemmaBasicConfig::deserialize(config, false)?;
796 let per_layer_elems = {
797 let input_layernorm = cfg.hidden_size;
798 let post_attention_layernorm = cfg.hidden_size;
799
800 let size_in = cfg.hidden_size;
801 let size_q = cfg.head_dim * cfg.num_attention_heads;
802 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
803 let q_proj =
804 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
805 let k_proj =
806 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
807 let v_proj =
808 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
809 let o_proj =
810 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
811
812 let h_size = cfg.hidden_size;
813 let i_size = cfg.intermediate_size;
814 let gate_proj = h_size * i_size / weight_pack_factor;
815 let up_proj = h_size * i_size / weight_pack_factor;
816 let down_proj = i_size * h_size / weight_pack_factor;
817
818 input_layernorm
819 + post_attention_layernorm
820 + q_proj
821 + k_proj
822 + v_proj
823 + o_proj
824 + gate_proj
825 + up_proj
826 + down_proj
827 };
828 Ok(vec![
829 per_layer_elems * dtype.size_in_bytes();
830 cfg.num_hidden_layers
831 ])
832 }
833
834 fn num_layers(&self, config: &str) -> Result<usize> {
835 let cfg = GemmaBasicConfig::deserialize(config, false)?;
836 Ok(cfg.num_hidden_layers)
837 }
838
839 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
840 let cfg = GemmaBasicConfig::deserialize(config, false)?;
841
842 let cfg = ModelConfigMetadata {
843 max_seq_len: cfg.max_position_embeddings,
844 num_layers: cfg.num_hidden_layers,
845 hidden_size: cfg.hidden_size,
846 num_kv_heads: cfg.num_key_value_heads,
847 num_attn_heads: cfg.num_attention_heads,
848 sliding_window: None,
849 k_head_dim: cfg.head_dim,
850 v_head_dim: cfg.head_dim,
851 };
852
853 Ok(Box::new(cfg))
854 }
855}
856
857#[derive(Deserialize)]
860struct LlamaBasicConfig {
861 hidden_act: Activation,
862 hidden_size: usize,
863 intermediate_size: usize,
864 vocab_size: usize,
865 num_hidden_layers: usize,
866 num_attention_heads: usize,
867 num_key_value_heads: Option<usize>,
868 rms_norm_eps: f64,
869 #[serde(default = "default_rope")]
870 rope_theta: f32,
871 max_position_embeddings: usize,
872 rope_scaling: Option<Llama3RopeConfig>,
873 quantization_config: Option<QuantizedConfig>,
874 #[serde(default = "word_emb_default")]
875 tie_word_embeddings: bool,
876}
877
878fn default_rope() -> f32 {
879 10_000.0
880}
881
882impl LlamaBasicConfig {
883 fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::llama::Config> {
884 let basic_config: Self = serde_json::from_str(slice)?;
885 Ok(models::llama::Config {
886 hidden_size: basic_config.hidden_size,
887 intermediate_size: basic_config.intermediate_size,
888 vocab_size: basic_config.vocab_size,
889 num_hidden_layers: basic_config.num_hidden_layers,
890 num_attention_heads: basic_config.num_attention_heads,
891 num_key_value_heads: basic_config
892 .num_key_value_heads
893 .unwrap_or(basic_config.num_attention_heads),
894 rms_norm_eps: basic_config.rms_norm_eps,
895 rope_theta: basic_config.rope_theta,
896 use_flash_attn,
897 max_position_embeddings: basic_config.max_position_embeddings,
898 rope_scaling: basic_config.rope_scaling,
899 quantization_config: basic_config.quantization_config,
900 tie_word_embeddings: basic_config.tie_word_embeddings,
901 hidden_act: basic_config.hidden_act,
902 })
903 }
904}
905
906pub struct LlamaLoader;
910
911impl NormalModelLoader for LlamaLoader {
912 fn load(
913 &self,
914 config: &str,
915 use_flash_attn: bool,
916 vb: ShardedVarBuilder,
917 normal_loading_metadata: NormalLoadingMetadata,
918 attention_mechanism: AttentionImplementation,
919 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
920 Ok(Box::new(models::llama::Llama::new(
921 &LlamaBasicConfig::deserialize(config, use_flash_attn)?,
922 vb,
923 self.is_gptx(config)?,
924 normal_loading_metadata,
925 attention_mechanism,
926 )?))
927 }
928 fn load_xlora(
929 &self,
930 config: &str,
931 use_flash_attn: bool,
932 vb: ShardedVarBuilder,
933 lora_config: &[((String, String), LoraConfig)],
934 xlora_config: Option<XLoraConfig>,
935 xlora_ordering: Ordering,
936 normal_loading_metadata: NormalLoadingMetadata,
937 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
938 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
939 Ok(Box::new(xlora_models::XLoraLlama::new(
940 &LlamaBasicConfig::deserialize(config, use_flash_attn)?,
941 vb,
942 lora_config,
943 xlora_config,
944 xlora_ordering,
945 self.is_gptx(config)?,
946 normal_loading_metadata,
947 preload_adapters,
948 )?))
949 }
950 fn is_gptx(&self, _: &str) -> Result<bool> {
951 Ok(true)
952 }
953 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
954 Ok(Box::new(LlamaBasicConfig::deserialize(
955 config,
956 use_flash_attn,
957 )?))
958 }
959}
960
961impl IsqModelLoader for LlamaLoader {
962 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
963 Ok(vec![
964 Regex::new(r"lm_head\.(weight|bias)$")?,
965 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
967 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
968 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
969 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
970 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
972 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
973 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
974 ])
975 }
976}
977
978impl DeviceMappedModelLoader for LlamaLoader {
979 fn mapped_max_act_size_elems(
980 &self,
981 config: &str,
982 params: &AutoDeviceMapParams,
983 prompt_chunksize: usize,
984 ) -> Result<usize> {
985 let AutoDeviceMapParams::Text {
986 max_seq_len: _,
987 max_batch_size,
988 } = params
989 else {
990 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
991 };
992
993 let cfg = LlamaBasicConfig::deserialize(config, false)?;
994
995 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
996 }
997 fn non_mapped_max_act_size_elems(
998 &self,
999 _config: &str,
1000 _params: &AutoDeviceMapParams,
1001 ) -> Result<usize> {
1002 Ok(0)
1003 }
1004
1005 fn non_mapped_size_in_bytes(
1006 &self,
1007 config: &str,
1008 dtype: DType,
1009 weight_pack_factor: usize,
1010 ) -> Result<usize> {
1011 let cfg = LlamaBasicConfig::deserialize(config, false)?;
1012 let elems = {
1013 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1014 let lm_head = if !cfg.tie_word_embeddings {
1015 cfg.hidden_size * cfg.vocab_size
1016 } else {
1017 0
1018 };
1019 let norm = cfg.hidden_size;
1020 embed_tokens + lm_head + norm
1021 };
1022 Ok(elems * dtype.size_in_bytes())
1023 }
1024
1025 fn layer_sizes_in_bytes(
1026 &self,
1027 config: &str,
1028 dtype: DType,
1029 weight_pack_factor: usize,
1030 ) -> Result<Vec<usize>> {
1031 let cfg = LlamaBasicConfig::deserialize(config, false)?;
1032 let per_layer_elems = {
1033 let input_layernorm = cfg.hidden_size;
1034 let post_attention_layernorm = cfg.hidden_size;
1035
1036 let size_in = cfg.hidden_size;
1037 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1038 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1039 let q_proj = size_in * size_q / weight_pack_factor;
1040 let k_proj = size_in * size_kv / weight_pack_factor;
1041 let v_proj = size_in * size_kv / weight_pack_factor;
1042 let o_proj = size_q * size_in / weight_pack_factor;
1043
1044 let h_size = cfg.hidden_size;
1045 let i_size = cfg.intermediate_size;
1046 let gate_proj = h_size * i_size / weight_pack_factor;
1047 let up_proj = h_size * i_size / weight_pack_factor;
1048 let down_proj = i_size * h_size / weight_pack_factor;
1049
1050 input_layernorm
1051 + post_attention_layernorm
1052 + q_proj
1053 + k_proj
1054 + v_proj
1055 + o_proj
1056 + gate_proj
1057 + up_proj
1058 + down_proj
1059 };
1060 Ok(vec![
1061 per_layer_elems * dtype.size_in_bytes();
1062 cfg.num_hidden_layers
1063 ])
1064 }
1065
1066 fn num_layers(&self, config: &str) -> Result<usize> {
1067 let cfg = LlamaBasicConfig::deserialize(config, false)?;
1068 Ok(cfg.num_hidden_layers)
1069 }
1070 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1071 let cfg = LlamaBasicConfig::deserialize(config, false)?;
1072
1073 let cfg = ModelConfigMetadata {
1074 max_seq_len: cfg.max_position_embeddings,
1075 num_layers: cfg.num_hidden_layers,
1076 hidden_size: cfg.hidden_size,
1077 num_kv_heads: cfg.num_key_value_heads,
1078 num_attn_heads: cfg.num_attention_heads,
1079 sliding_window: None,
1080 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1081 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1082 };
1083
1084 Ok(Box::new(cfg))
1085 }
1086}
1087
1088#[derive(Deserialize)]
1091struct MixtralBasicConfig {
1092 vocab_size: usize,
1093 hidden_size: usize,
1094 intermediate_size: usize,
1095 num_hidden_layers: usize,
1096 num_attention_heads: usize,
1097 num_key_value_heads: usize,
1098 hidden_act: Activation,
1099 max_position_embeddings: usize,
1100 rms_norm_eps: f64,
1101 rope_theta: f64,
1102 sliding_window: Option<usize>,
1103 num_experts_per_tok: usize,
1104 num_local_experts: usize,
1105 quantization_config: Option<QuantizedConfig>,
1106 #[serde(default = "word_emb_default")]
1107 tie_word_embeddings: bool,
1108}
1109
1110impl MixtralBasicConfig {
1111 fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::mixtral::Config> {
1112 let basic_config: Self = serde_json::from_str(slice)?;
1113 Ok(models::mixtral::Config {
1114 vocab_size: basic_config.vocab_size,
1115 hidden_size: basic_config.hidden_size,
1116 intermediate_size: basic_config.intermediate_size,
1117 num_hidden_layers: basic_config.num_hidden_layers,
1118 num_attention_heads: basic_config.num_attention_heads,
1119 num_key_value_heads: basic_config.num_key_value_heads,
1120 hidden_act: basic_config.hidden_act,
1121 max_position_embeddings: basic_config.max_position_embeddings,
1122 rms_norm_eps: basic_config.rms_norm_eps,
1123 rope_theta: basic_config.rope_theta,
1124 sliding_window: basic_config.sliding_window,
1125 use_flash_attn,
1126 num_experts_per_tok: basic_config.num_experts_per_tok,
1127 num_local_experts: basic_config.num_local_experts,
1128 quantization_config: basic_config.quantization_config,
1129 tie_word_embeddings: basic_config.tie_word_embeddings,
1130 })
1131 }
1132}
1133
1134pub struct MixtralLoader;
1135
1136impl NormalModelLoader for MixtralLoader {
1137 fn load(
1138 &self,
1139 config: &str,
1140 use_flash_attn: bool,
1141 vb: ShardedVarBuilder,
1142 normal_loading_metadata: NormalLoadingMetadata,
1143 attention_mechanism: AttentionImplementation,
1144 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1145 Ok(Box::new(models::mixtral::Model::new(
1146 &MixtralBasicConfig::deserialize(config, use_flash_attn)?,
1147 vb,
1148 self.is_gptx(config)?,
1149 normal_loading_metadata,
1150 attention_mechanism,
1151 )?))
1152 }
1153 fn load_xlora(
1154 &self,
1155 config: &str,
1156 use_flash_attn: bool,
1157 vb: ShardedVarBuilder,
1158 lora_config: &[((String, String), LoraConfig)],
1159 xlora_config: Option<XLoraConfig>,
1160 xlora_ordering: Ordering,
1161 normal_loading_metadata: NormalLoadingMetadata,
1162 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1163 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1164 Ok(Box::new(xlora_models::XLoraMixtral::new(
1165 &MixtralBasicConfig::deserialize(config, use_flash_attn)?,
1166 vb,
1167 lora_config,
1168 xlora_config,
1169 xlora_ordering,
1170 self.is_gptx(config)?,
1171 normal_loading_metadata,
1172 preload_adapters,
1173 )?))
1174 }
1175 fn is_gptx(&self, _: &str) -> Result<bool> {
1176 Ok(true)
1177 }
1178 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1179 Ok(Box::new(MixtralBasicConfig::deserialize(
1180 config,
1181 use_flash_attn,
1182 )?))
1183 }
1184}
1185
1186impl IsqModelLoader for MixtralLoader {
1187 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1188 Ok(vec![
1189 Regex::new(r"lm_head\.(weight|bias)$")?,
1190 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1192 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1193 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1194 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1195 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1197 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1198 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1199 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1200 ])
1201 }
1202}
1203
1204impl DeviceMappedModelLoader for MixtralLoader {
1205 fn mapped_max_act_size_elems(
1206 &self,
1207 config: &str,
1208 params: &AutoDeviceMapParams,
1209 prompt_chunksize: usize,
1210 ) -> Result<usize> {
1211 let AutoDeviceMapParams::Text {
1212 max_seq_len: _,
1213 max_batch_size,
1214 } = params
1215 else {
1216 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1217 };
1218
1219 let cfg = MixtralBasicConfig::deserialize(config, false)?;
1220
1221 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1222 }
1223 fn non_mapped_max_act_size_elems(
1224 &self,
1225 _config: &str,
1226 _params: &AutoDeviceMapParams,
1227 ) -> Result<usize> {
1228 Ok(0)
1229 }
1230
1231 fn non_mapped_size_in_bytes(
1232 &self,
1233 config: &str,
1234 dtype: DType,
1235 weight_pack_factor: usize,
1236 ) -> Result<usize> {
1237 let cfg = MixtralBasicConfig::deserialize(config, false)?;
1238 let elems = {
1239 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1240 let lm_head = if !cfg.tie_word_embeddings {
1241 cfg.hidden_size * cfg.vocab_size
1242 } else {
1243 0
1244 };
1245 let norm = cfg.hidden_size;
1246 embed_tokens + lm_head + norm
1247 };
1248 Ok(elems * dtype.size_in_bytes())
1249 }
1250
1251 fn layer_sizes_in_bytes(
1252 &self,
1253 config: &str,
1254 dtype: DType,
1255 weight_pack_factor: usize,
1256 ) -> Result<Vec<usize>> {
1257 let cfg = MixtralBasicConfig::deserialize(config, false)?;
1258 let per_layer_elems = {
1259 let input_layernorm = cfg.hidden_size;
1260 let post_attention_layernorm = cfg.hidden_size;
1261
1262 let size_in = cfg.hidden_size;
1263 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1264 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1265 let q_proj = size_in * size_q / weight_pack_factor;
1266 let k_proj = size_in * size_kv / weight_pack_factor;
1267 let v_proj = size_in * size_kv / weight_pack_factor;
1268 let o_proj = size_q * size_in / weight_pack_factor;
1269
1270 let moe_block = {
1271 let gate = cfg.hidden_size * cfg.num_local_experts;
1272 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1274 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1275 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1276 gate + cfg.num_local_experts * w1
1277 + cfg.num_local_experts * w2
1278 + cfg.num_local_experts * w3
1279 };
1280
1281 input_layernorm
1282 + post_attention_layernorm
1283 + q_proj
1284 + k_proj
1285 + v_proj
1286 + o_proj
1287 + moe_block
1288 };
1289 Ok(vec![
1290 per_layer_elems * dtype.size_in_bytes();
1291 cfg.num_hidden_layers
1292 ])
1293 }
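    // Note that all `num_local_experts` experts are counted for every layer in the
    // estimate above: expert weights stay resident even though only
    // `num_experts_per_tok` are active per token. For Mixtral-8x7B-like dimensions
    // (hidden 4096, intermediate 14336, 8 experts), the expert weights alone are
    // 8 * 3 * 4096 * 14336, roughly 1.4e9 elements per layer, versus roughly 4.2e7
    // elements for the attention projections (illustrative figures).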
1294
1295 fn num_layers(&self, config: &str) -> Result<usize> {
1296 let cfg = MixtralBasicConfig::deserialize(config, false)?;
1297 Ok(cfg.num_hidden_layers)
1298 }
1299
1300 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1301 let cfg = MixtralBasicConfig::deserialize(config, false)?;
1302
1303 let cfg = ModelConfigMetadata {
1304 max_seq_len: cfg.max_position_embeddings,
1305 num_layers: cfg.num_hidden_layers,
1306 hidden_size: cfg.hidden_size,
1307 num_kv_heads: cfg.num_key_value_heads,
1308 num_attn_heads: cfg.num_attention_heads,
1309 sliding_window: cfg.sliding_window,
1310 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1311 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1312 };
1313
1314 Ok(Box::new(cfg))
1315 }
1316}
1317
1318#[derive(Deserialize)]
1321struct Phi2BasicConfig {
1322 vocab_size: usize,
1323 hidden_size: usize,
1324 intermediate_size: usize,
1325 num_hidden_layers: usize,
1326 num_attention_heads: usize,
1327 num_key_value_heads: Option<usize>,
1328 hidden_act: Activation,
1329 max_position_embeddings: usize,
1330 layer_norm_eps: f64,
1331 rope_theta: f32,
1332 partial_rotary_factor: f64,
1333 qk_layernorm: bool,
1334 quantization_config: Option<QuantizedConfig>,
1335 #[serde(default = "word_emb_default")]
1336 tie_word_embeddings: bool,
1337}
1338
1339impl Phi2BasicConfig {
1340 fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi2::Config> {
1341 let basic_config: Self = serde_json::from_str(slice)?;
1342 Ok(models::phi2::Config {
1343 vocab_size: basic_config.vocab_size,
1344 hidden_size: basic_config.hidden_size,
1345 intermediate_size: basic_config.intermediate_size,
1346 num_hidden_layers: basic_config.num_hidden_layers,
1347 num_attention_heads: basic_config.num_attention_heads,
1348 num_key_value_heads: basic_config.num_key_value_heads,
1349 hidden_act: basic_config.hidden_act,
1350 max_position_embeddings: basic_config.max_position_embeddings,
1351 rope_theta: basic_config.rope_theta,
1352 layer_norm_eps: basic_config.layer_norm_eps,
1353 partial_rotary_factor: basic_config.partial_rotary_factor,
1354 qk_layernorm: basic_config.qk_layernorm,
1355 use_flash_attn,
1356 quantization_config: basic_config.quantization_config,
1357 tie_word_embeddings: basic_config.tie_word_embeddings,
1358 })
1359 }
1360}
1361
1362pub struct Phi2Loader;
1366
1367impl NormalModelLoader for Phi2Loader {
1368 fn load(
1369 &self,
1370 config: &str,
1371 use_flash_attn: bool,
1372 vb: ShardedVarBuilder,
1373 normal_loading_metadata: NormalLoadingMetadata,
1374 attention_mechanism: AttentionImplementation,
1375 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1376 Ok(Box::new(models::phi2::Model::new(
1377 &Phi2BasicConfig::deserialize(config, use_flash_attn)?,
1378 vb,
1379 self.is_gptx(config)?,
1380 normal_loading_metadata,
1381 attention_mechanism,
1382 )?))
1383 }
1384 fn load_xlora(
1385 &self,
1386 config: &str,
1387 use_flash_attn: bool,
1388 vb: ShardedVarBuilder,
1389 lora_config: &[((String, String), LoraConfig)],
1390 xlora_config: Option<XLoraConfig>,
1391 xlora_ordering: Ordering,
1392 normal_loading_metadata: NormalLoadingMetadata,
1393 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1394 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1395 Ok(Box::new(xlora_models::XLoraPhi2::new(
1396 &Phi2BasicConfig::deserialize(config, use_flash_attn)?,
1397 vb,
1398 lora_config,
1399 xlora_config,
1400 xlora_ordering,
1401 self.is_gptx(config)?,
1402 normal_loading_metadata,
1403 preload_adapters,
1404 )?))
1405 }
1406 fn is_gptx(&self, _: &str) -> Result<bool> {
1407 Ok(true)
1408 }
1409 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1410 Ok(Box::new(Phi2BasicConfig::deserialize(
1411 config,
1412 use_flash_attn,
1413 )?))
1414 }
1415}
1416
1417impl IsqModelLoader for Phi2Loader {
1418 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1419 Ok(vec![
1420 Regex::new(r"lm_head\.(weight|bias)$")?,
1421 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1423 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1424 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1425 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1426 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1428 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1429 ])
1430 }
1431}
1432
1433impl DeviceMappedModelLoader for Phi2Loader {
1434 fn mapped_max_act_size_elems(
1435 &self,
1436 config: &str,
1437 params: &AutoDeviceMapParams,
1438 prompt_chunksize: usize,
1439 ) -> Result<usize> {
1440 let AutoDeviceMapParams::Text {
1441 max_seq_len: _,
1442 max_batch_size,
1443 } = params
1444 else {
1445 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1446 };
1447
1448 let cfg = Phi2BasicConfig::deserialize(config, false)?;
1449
1450 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1451 }
1452 fn non_mapped_max_act_size_elems(
1453 &self,
1454 _config: &str,
1455 _params: &AutoDeviceMapParams,
1456 ) -> Result<usize> {
1457 Ok(0)
1458 }
1459
1460 fn non_mapped_size_in_bytes(
1461 &self,
1462 config: &str,
1463 dtype: DType,
1464 weight_pack_factor: usize,
1465 ) -> Result<usize> {
1466 let cfg = Phi2BasicConfig::deserialize(config, false)?;
1467 let elems = {
1468 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1469 let lm_head = if !cfg.tie_word_embeddings {
1470 cfg.hidden_size * cfg.vocab_size
1471 } else {
1472 0
1473 };
1474 let norm = cfg.hidden_size;
1475 embed_tokens + lm_head + norm
1476 };
1477 Ok(elems * dtype.size_in_bytes())
1478 }
1479
1480 fn layer_sizes_in_bytes(
1481 &self,
1482 config: &str,
1483 dtype: DType,
1484 weight_pack_factor: usize,
1485 ) -> Result<Vec<usize>> {
1486 let cfg = Phi2BasicConfig::deserialize(config, false)?;
1487 let per_layer_elems = {
1488 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1489
1490 let size_in = cfg.hidden_size;
1491 let size_q = cfg.head_dim() * cfg.num_attention_heads;
1492 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1493 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1494 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1495 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1496 let o_proj = size_q * size_in / weight_pack_factor + size_in;
1497 let (q_norm, k_norm) = if cfg.qk_layernorm {
1498 (cfg.head_dim(), cfg.head_dim())
1499 } else {
1500 (0, 0)
1501 };
1502
1503 let h_size = cfg.hidden_size;
1504 let i_size = cfg.intermediate_size;
1505 let fc1 = h_size * i_size / weight_pack_factor;
1506 let fc2 = h_size * i_size / weight_pack_factor;
1507
1508 input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1509 };
1510 Ok(vec![
1511 per_layer_elems * dtype.size_in_bytes();
1512 cfg.num_hidden_layers
1513 ])
1514 }
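    // The estimate above reflects that Phi-2 uses LayerNorm with an affine bias
    // (hence the doubled `hidden_size` term) and biased attention projections (hence
    // the unconditional `+ size_*` terms), in contrast to the RMSNorm-based models
    // above.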
1515
1516 fn num_layers(&self, config: &str) -> Result<usize> {
1517 let cfg = Phi2BasicConfig::deserialize(config, false)?;
1518 Ok(cfg.num_hidden_layers)
1519 }
1520
1521 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1522 let cfg = Phi2BasicConfig::deserialize(config, false)?;
1523
1524 let cfg = ModelConfigMetadata {
1525 max_seq_len: cfg.max_position_embeddings,
1526 num_layers: cfg.num_hidden_layers,
1527 hidden_size: cfg.hidden_size,
1528 num_kv_heads: cfg.num_key_value_heads(),
1529 num_attn_heads: cfg.num_attention_heads,
1530 sliding_window: None,
1531 k_head_dim: cfg.head_dim(),
1532 v_head_dim: cfg.head_dim(),
1533 };
1534
1535 Ok(Box::new(cfg))
1536 }
1537}
1538
1539#[derive(Deserialize)]
1542struct Phi3BasicConfig {
1543 vocab_size: usize,
1544 hidden_act: Activation,
1545 hidden_size: usize,
1546 intermediate_size: usize,
1547 num_hidden_layers: usize,
1548 num_attention_heads: usize,
1549 num_key_value_heads: usize,
1550 rms_norm_eps: f64,
1551 rope_theta: f64,
1552 bos_token_id: Option<u32>,
1553 eos_token_id: Option<u32>,
1554 rope_scaling: Option<PhiRopeScalingConfig>,
1555 max_position_embeddings: usize,
1556 original_max_position_embeddings: usize,
1557 sliding_window: Option<usize>,
1558 quantization_config: Option<QuantizedConfig>,
1559 #[serde(default = "word_emb_default")]
1560 tie_word_embeddings: bool,
1561 partial_rotary_factor: Option<f64>,
1562}
1563
1564impl Phi3BasicConfig {
1565 fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi3::Config> {
1566 let basic_config: Self = serde_json::from_str(slice)?;
1567 Ok(models::phi3::Config {
1568 vocab_size: basic_config.vocab_size,
1569 hidden_size: basic_config.hidden_size,
1570 intermediate_size: basic_config.intermediate_size,
1571 num_hidden_layers: basic_config.num_hidden_layers,
1572 num_attention_heads: basic_config.num_attention_heads,
1573 num_key_value_heads: basic_config.num_key_value_heads,
1574 hidden_act: basic_config.hidden_act,
1575 max_position_embeddings: basic_config.max_position_embeddings,
1576 rope_theta: basic_config.rope_theta,
1577 rms_norm_eps: basic_config.rms_norm_eps,
1578 eos_token_id: basic_config.eos_token_id,
1579 bos_token_id: basic_config.bos_token_id,
1580 rope_scaling: basic_config.rope_scaling,
1581 original_max_position_embeddings: basic_config.original_max_position_embeddings,
1582 use_flash_attn,
1583 sliding_window: basic_config.sliding_window,
1584 quantization_config: basic_config.quantization_config,
1585 tie_word_embeddings: basic_config.tie_word_embeddings,
1586 partial_rotary_factor: basic_config.partial_rotary_factor,
1587 })
1588 }
1589}
1590
1591pub struct Phi3Loader;
1595
1596impl NormalModelLoader for Phi3Loader {
1597 fn load(
1598 &self,
1599 config: &str,
1600 use_flash_attn: bool,
1601 vb: ShardedVarBuilder,
1602 normal_loading_metadata: NormalLoadingMetadata,
1603 attention_mechanism: AttentionImplementation,
1604 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1605 Ok(Box::new(models::phi3::Model::new(
1606 &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
1607 vb,
1608 self.is_gptx(config)?,
1609 normal_loading_metadata,
1610 attention_mechanism,
1611 )?))
1612 }
1613 fn load_xlora(
1614 &self,
1615 config: &str,
1616 use_flash_attn: bool,
1617 vb: ShardedVarBuilder,
1618 lora_config: &[((String, String), LoraConfig)],
1619 xlora_config: Option<XLoraConfig>,
1620 xlora_ordering: Ordering,
1621 normal_loading_metadata: NormalLoadingMetadata,
1622 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1623 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1624 Ok(Box::new(xlora_models::XLoraPhi3::new(
1625 &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
1626 vb,
1627 lora_config,
1628 xlora_config,
1629 xlora_ordering,
1630 self.is_gptx(config)?,
1631 normal_loading_metadata,
1632 preload_adapters,
1633 )?))
1634 }
1635 fn is_gptx(&self, _: &str) -> Result<bool> {
1636 Ok(true)
1637 }
1638 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1639 Ok(Box::new(Phi3BasicConfig::deserialize(
1640 config,
1641 use_flash_attn,
1642 )?))
1643 }
1644}
1645
1646impl IsqModelLoader for Phi3Loader {
1647 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1648 Ok(vec![
1649 Regex::new(r"lm_head\.(weight|bias)$")?,
1650 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1652 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1653 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1655 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1656 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1657 ])
1658 }
1659}
1660
1661impl DeviceMappedModelLoader for Phi3Loader {
1662 fn mapped_max_act_size_elems(
1663 &self,
1664 config: &str,
1665 params: &AutoDeviceMapParams,
1666 prompt_chunksize: usize,
1667 ) -> Result<usize> {
1668 let AutoDeviceMapParams::Text {
1669 max_seq_len: _,
1670 max_batch_size,
1671 } = params
1672 else {
1673 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1674 };
1675
1676 let cfg = Phi3BasicConfig::deserialize(config, false)?;
1677
1678 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1679 }
1680 fn non_mapped_max_act_size_elems(
1681 &self,
1682 _config: &str,
1683 _params: &AutoDeviceMapParams,
1684 ) -> Result<usize> {
1685 Ok(0)
1686 }
1687
1688 fn non_mapped_size_in_bytes(
1689 &self,
1690 config: &str,
1691 dtype: DType,
1692 weight_pack_factor: usize,
1693 ) -> Result<usize> {
1694 let cfg = Phi3BasicConfig::deserialize(config, false)?;
1695 let elems = {
1696 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1697 let lm_head = if !cfg.tie_word_embeddings {
1698 cfg.hidden_size * cfg.vocab_size
1699 } else {
1700 0
1701 };
1702 let norm = cfg.hidden_size;
1703 embed_tokens + lm_head + norm
1704 };
1705 Ok(elems * dtype.size_in_bytes())
1706 }
1707
1708 fn layer_sizes_in_bytes(
1709 &self,
1710 config: &str,
1711 dtype: DType,
1712 weight_pack_factor: usize,
1713 ) -> Result<Vec<usize>> {
1714 let cfg = Phi3BasicConfig::deserialize(config, false)?;
1715 let per_layer_elems = {
1716 let input_layernorm = cfg.hidden_size;
1717 let post_attention_layernorm = cfg.hidden_size;
1718
1719 let size_in = cfg.hidden_size;
1720 let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1722 let qkv_proj = size_in * op_size / weight_pack_factor;
1723 let o_proj =
1724 (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1725
1726 let h_size = cfg.hidden_size;
1727 let i_size = cfg.intermediate_size;
1728 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1729 let down_proj = h_size * i_size / weight_pack_factor;
1730
1731 input_layernorm
1732 + post_attention_layernorm
1733 + qkv_proj
1734 + o_proj
1735 + gate_up_proj
1736 + down_proj
1737 };
1738 Ok(vec![
1739 per_layer_elems * dtype.size_in_bytes();
1740 cfg.num_hidden_layers
1741 ])
1742 }
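    // The estimate above reflects that Phi-3 fuses the attention projections into a
    // single `qkv_proj` (hence `op_size`, the query rows plus twice the key/value
    // rows) and fuses the MLP gate and up projections into `gate_up_proj` (hence the
    // `2 * i_size` factor), so it is shaped differently from the per-projection models
    // above.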
1743
1744 fn num_layers(&self, config: &str) -> Result<usize> {
1745 let cfg = Phi3BasicConfig::deserialize(config, false)?;
1746 Ok(cfg.num_hidden_layers)
1747 }
1748
1749 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1750 let cfg = Phi3BasicConfig::deserialize(config, false)?;
1751
1752 let cfg = ModelConfigMetadata {
1753 max_seq_len: cfg.max_position_embeddings,
1754 num_layers: cfg.num_hidden_layers,
1755 hidden_size: cfg.hidden_size,
1756 num_kv_heads: cfg.num_key_value_heads,
1757 num_attn_heads: cfg.num_attention_heads,
1758 sliding_window: cfg.sliding_window,
1759 k_head_dim: cfg.head_dim(),
1760 v_head_dim: cfg.head_dim(),
1761 };
1762
1763 Ok(Box::new(cfg))
1764 }
1765}
1766
1767#[derive(Deserialize)]
1770struct Qwen2BasicConfig {
1771 vocab_size: usize,
1772 hidden_size: usize,
1773 intermediate_size: usize,
1774 num_hidden_layers: usize,
1775 num_attention_heads: usize,
1776 num_key_value_heads: usize,
1777 max_position_embeddings: usize,
1778 sliding_window: Option<usize>,
1779 rope_theta: f64,
1780 rms_norm_eps: f64,
1781 hidden_act: Activation,
1782 quantization_config: Option<QuantizedConfig>,
1783 tie_word_embeddings: bool,
1784}
1785
1786impl Qwen2BasicConfig {
1787 fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::qwen2::Config> {
1788 let basic_config: Self = serde_json::from_str(slice)?;
1789 Ok(models::qwen2::Config {
1790 vocab_size: basic_config.vocab_size,
1791 hidden_size: basic_config.hidden_size,
1792 intermediate_size: basic_config.intermediate_size,
1793 num_hidden_layers: basic_config.num_hidden_layers,
1794 num_attention_heads: basic_config.num_attention_heads,
1795 num_key_value_heads: basic_config.num_key_value_heads,
1796 hidden_act: basic_config.hidden_act,
1797 max_position_embeddings: basic_config.max_position_embeddings,
1798 rope_theta: basic_config.rope_theta,
1799 rms_norm_eps: basic_config.rms_norm_eps,
1800 sliding_window: basic_config.sliding_window,
1801 use_flash_attn,
1802 quantization_config: basic_config.quantization_config,
1803 tie_word_embeddings: basic_config.tie_word_embeddings,
1804 })
1805 }
1806}
1807
1808pub struct Qwen2Loader;
1812
1813impl NormalModelLoader for Qwen2Loader {
1814 fn load(
1815 &self,
1816 config: &str,
1817 use_flash_attn: bool,
1818 vb: ShardedVarBuilder,
1819 normal_loading_metadata: NormalLoadingMetadata,
1820 attention_mechanism: AttentionImplementation,
1821 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1822 Ok(Box::new(models::qwen2::Model::new(
1823 &Qwen2BasicConfig::deserialize(config, use_flash_attn)?,
1824 vb,
1825 self.is_gptx(config)?,
1826 normal_loading_metadata,
1827 attention_mechanism,
1828 )?))
1829 }
1830 fn load_xlora(
1831 &self,
1832 _config: &str,
1833 _use_flash_attn: bool,
1834 _vb: ShardedVarBuilder,
1835 _lora_config: &[((String, String), LoraConfig)],
1836 _xlora_config: Option<XLoraConfig>,
1837 _xlora_ordering: Ordering,
1838 _normal_loading_metadata: NormalLoadingMetadata,
1839 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1840 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1841 todo!()
1842 }
1843 fn is_gptx(&self, _: &str) -> Result<bool> {
1844 Ok(true)
1845 }
1846 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1847 Ok(Box::new(Qwen2BasicConfig::deserialize(
1848 config,
1849 use_flash_attn,
1850 )?))
1851 }
1852}
1853
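// The ISQ regexes select which tensors are eligible for in-situ quantization; they are
// matched against weight names such as `layers.0.self_attn.q_proj.weight`. A minimal
// sanity-check sketch (the tensor name below is illustrative only):
//
//     let re = Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?;
//     assert!(re.is_match("model.layers.3.self_attn.q_proj.weight"));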
1854impl IsqModelLoader for Qwen2Loader {
1855 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1856 Ok(vec![
1857 Regex::new(r"lm_head\.(weight|bias)$")?,
1858 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1860 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1861 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1862 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1863 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1865 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1866 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1867 ])
1868 }
1869}
1870
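// The device-mapping estimates below are computed from the config alone, so the pipeline
// can plan layer placement up front: `mapped_max_act_size_elems` bounds the attention
// score activations (`max_batch_size * num_attention_heads * prompt_chunksize^2`), and
// the `*_size_in_bytes` methods count weight elements. `weight_pack_factor` divides out
// weights that a quantized format packs into a single stored element.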
1871impl DeviceMappedModelLoader for Qwen2Loader {
1872 fn mapped_max_act_size_elems(
1873 &self,
1874 config: &str,
1875 params: &AutoDeviceMapParams,
1876 prompt_chunksize: usize,
1877 ) -> Result<usize> {
1878 let AutoDeviceMapParams::Text {
1879 max_seq_len: _,
1880 max_batch_size,
1881 } = params
1882 else {
1883 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1884 };
1885
1886 let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1887
1888 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1889 }
1890 fn non_mapped_max_act_size_elems(
1891 &self,
1892 _config: &str,
1893 _params: &AutoDeviceMapParams,
1894 ) -> Result<usize> {
1895 Ok(0)
1896 }
1897
1898 fn non_mapped_size_in_bytes(
1899 &self,
1900 config: &str,
1901 dtype: DType,
1902 weight_pack_factor: usize,
1903 ) -> Result<usize> {
1904 let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1905 let elems = {
1906 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1907 let lm_head = if !cfg.tie_word_embeddings {
1908 cfg.hidden_size * cfg.vocab_size
1909 } else {
1910 0
1911 };
1912 let norm = cfg.hidden_size;
1913 embed_tokens + lm_head + norm
1914 };
1915 Ok(elems * dtype.size_in_bytes())
1916 }
1917
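// Per-decoder-layer weight footprint in bytes. Rough worked example under hypothetical
// dimensions (not any particular checkpoint): hidden_size = 2048, 16 attention heads,
// 8 KV heads, intermediate_size = 5632, weight_pack_factor = 1, 2-byte dtype:
//
//     q/k/v/o projections ≈ 12.59M elements, gate/up/down ≈ 34.60M, norms ≈ 4K,
//     so one layer ≈ 47.19M elements ≈ 90 MiB.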
1918 fn layer_sizes_in_bytes(
1919 &self,
1920 config: &str,
1921 dtype: DType,
1922 weight_pack_factor: usize,
1923 ) -> Result<Vec<usize>> {
1924 let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1925 let per_layer_elems = {
1926 let input_layernorm = cfg.hidden_size;
1927 let post_attention_layernorm = cfg.hidden_size;
1928
1929 let size_in = cfg.hidden_size;
1930 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1931 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1932 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1933 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1934 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1935 let o_proj = size_q * size_in / weight_pack_factor;
1936
1937 let h_size = cfg.hidden_size;
1938 let i_size = cfg.intermediate_size;
1939 let gate_proj = h_size * i_size / weight_pack_factor;
1940 let up_proj = h_size * i_size / weight_pack_factor;
1941 let down_proj = i_size * h_size / weight_pack_factor;
1942
1943 input_layernorm
1944 + post_attention_layernorm
1945 + q_proj
1946 + k_proj
1947 + v_proj
1948 + o_proj
1949 + gate_proj
1950 + up_proj
1951 + down_proj
1952 };
1953 Ok(vec![
1954 per_layer_elems * dtype.size_in_bytes();
1955 cfg.num_hidden_layers
1956 ])
1957 }
1958
1959 fn num_layers(&self, config: &str) -> Result<usize> {
1960 let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1961 Ok(cfg.num_hidden_layers)
1962 }
1963
1964 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1965 let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1966
1967 let cfg = ModelConfigMetadata {
1968 max_seq_len: cfg.max_position_embeddings,
1969 num_layers: cfg.num_hidden_layers,
1970 hidden_size: cfg.hidden_size,
1971 num_kv_heads: cfg.num_key_value_heads,
1972 num_attn_heads: cfg.num_attention_heads,
1973 sliding_window: cfg.sliding_window,
1974 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1975 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1976 };
1977
1978 Ok(Box::new(cfg))
1979 }
1980}
1981
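// Serde mirror of the Gemma 2 `config.json`. Checkpoints in the wild name the activation
// either `hidden_act` or `hidden_activation`, so both are kept as `Option`s and resolved
// downstream; the logit-softcapping and `query_pre_attn_scalar` fields are Gemma 2
// specific.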
1982#[derive(Deserialize)]
1985struct Gemma2BasicConfig {
1986 attention_bias: bool,
1987 head_dim: usize,
1988 hidden_act: Option<Activation>,
1990 hidden_activation: Option<Activation>,
1991 hidden_size: usize,
1992 intermediate_size: usize,
1993 num_attention_heads: usize,
1994 num_hidden_layers: usize,
1995 num_key_value_heads: usize,
1996 rms_norm_eps: f64,
1997 rope_theta: f64,
1998 vocab_size: usize,
1999 sliding_window: usize,
2000 attn_logit_softcapping: Option<f64>,
2001 final_logit_softcapping: Option<f64>,
2002 query_pre_attn_scalar: usize,
2003
2004 #[serde(default = "default_max_position_embeddings")]
2005 max_position_embeddings: usize,
2006 quantization_config: Option<QuantizedConfig>,
2007 #[serde(default = "word_emb_default")]
2008 tie_word_embeddings: bool,
2009}
2010
2011impl Gemma2BasicConfig {
2012 fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::gemma2::Config> {
2013 let basic_config: Self = serde_json::from_str(slice)?;
2014 Ok(models::gemma2::Config {
2015 vocab_size: basic_config.vocab_size,
2016 hidden_size: basic_config.hidden_size,
2017 intermediate_size: basic_config.intermediate_size,
2018 num_hidden_layers: basic_config.num_hidden_layers,
2019 num_attention_heads: basic_config.num_attention_heads,
2020 num_key_value_heads: basic_config.num_key_value_heads,
2021 hidden_act: basic_config.hidden_act,
2022 hidden_activation: basic_config.hidden_activation,
2023 max_position_embeddings: basic_config.max_position_embeddings,
2024 rms_norm_eps: basic_config.rms_norm_eps,
2025 rope_theta: basic_config.rope_theta,
2026 attention_bias: basic_config.attention_bias,
2027 head_dim: basic_config.head_dim,
2028 use_flash_attn,
2029 quantization_config: basic_config.quantization_config,
2030 sliding_window: basic_config.sliding_window,
2031 attn_logit_softcapping: basic_config.attn_logit_softcapping,
2032 final_logit_softcapping: basic_config.final_logit_softcapping,
2033 query_pre_attn_scalar: basic_config.query_pre_attn_scalar,
2034 tie_word_embeddings: basic_config.tie_word_embeddings,
2035 })
2036 }
2037}
2038
2039pub struct Gemma2Loader;
2043
2044impl NormalModelLoader for Gemma2Loader {
2045 fn load(
2046 &self,
2047 config: &str,
2048 use_flash_attn: bool,
2049 vb: ShardedVarBuilder,
2050 normal_loading_metadata: NormalLoadingMetadata,
2051 attention_mechanism: AttentionImplementation,
2052 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2053 Ok(Box::new(models::gemma2::Model::new(
2054 &Gemma2BasicConfig::deserialize(config, use_flash_attn)?,
2055 vb,
2056 self.is_gptx(config)?,
2057 normal_loading_metadata,
2058 attention_mechanism,
2059 )?))
2060 }
2061 fn load_xlora(
2062 &self,
2063 config: &str,
2064 use_flash_attn: bool,
2065 vb: ShardedVarBuilder,
2066 lora_config: &[((String, String), LoraConfig)],
2067 xlora_config: Option<XLoraConfig>,
2068 xlora_ordering: Ordering,
2069 normal_loading_metadata: NormalLoadingMetadata,
2070 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2071 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2072 Ok(Box::new(xlora_models::XLoraGemma2::new(
2073 &Gemma2BasicConfig::deserialize(config, use_flash_attn)?,
2074 vb,
2075 lora_config,
2076 xlora_config,
2077 xlora_ordering,
2078 self.is_gptx(config)?,
2079 normal_loading_metadata,
2080 preload_adapters,
2081 )?))
2082 }
2083 fn is_gptx(&self, _: &str) -> Result<bool> {
2084 Ok(true)
2085 }
2086 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2087 Ok(Box::new(Gemma2BasicConfig::deserialize(
2089 config,
2090 use_flash_attn,
2091 )?))
2092 }
2093}
2094
2095impl IsqModelLoader for Gemma2Loader {
2096 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2097 Ok(vec![
2098 Regex::new(r"lm_head\.(weight|bias)$")?,
2099 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2101 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2102 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2103 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
2104 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2106 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2107 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2108 ])
2109 }
2110}
2111
2112impl DeviceMappedModelLoader for Gemma2Loader {
2113 fn mapped_max_act_size_elems(
2114 &self,
2115 config: &str,
2116 params: &AutoDeviceMapParams,
2117 prompt_chunksize: usize,
2118 ) -> Result<usize> {
2119 let AutoDeviceMapParams::Text {
2120 max_seq_len: _,
2121 max_batch_size,
2122 } = params
2123 else {
2124 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2125 };
2126
2127 let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2128
2129 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2130 }
2131 fn non_mapped_max_act_size_elems(
2132 &self,
2133 _config: &str,
2134 _params: &AutoDeviceMapParams,
2135 ) -> Result<usize> {
2136 Ok(0)
2137 }
2138
2139 fn non_mapped_size_in_bytes(
2140 &self,
2141 config: &str,
2142 dtype: DType,
2143 weight_pack_factor: usize,
2144 ) -> Result<usize> {
2145 let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2146 let elems = {
2147 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2148 let lm_head = if !cfg.tie_word_embeddings {
2149 cfg.hidden_size * cfg.vocab_size
2150 } else {
2151 0
2152 };
2153 let norm = cfg.hidden_size;
2154 embed_tokens + lm_head + norm
2155 };
2156 Ok(elems * dtype.size_in_bytes())
2157 }
2158
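// Unlike Qwen2 above, Gemma 2 sizes the attention projections from the explicit
// `head_dim` field, and bias elements are only counted when `attention_bias` is set
// (the `bias_if!` terms).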
2159 fn layer_sizes_in_bytes(
2160 &self,
2161 config: &str,
2162 dtype: DType,
2163 weight_pack_factor: usize,
2164 ) -> Result<Vec<usize>> {
2165 let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2166 let per_layer_elems = {
2167 let input_layernorm = cfg.hidden_size;
2168 let post_attention_layernorm = cfg.hidden_size;
2169
2170 let size_in = cfg.hidden_size;
2171 let size_q = cfg.head_dim * cfg.num_attention_heads;
2172 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
2173 let q_proj =
2174 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2175 let k_proj =
2176 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2177 let v_proj =
2178 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2179 let o_proj =
2180 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2181
2182 let h_size = cfg.hidden_size;
2183 let i_size = cfg.intermediate_size;
2184 let gate_proj = h_size * i_size / weight_pack_factor;
2185 let up_proj = h_size * i_size / weight_pack_factor;
2186 let down_proj = i_size * h_size / weight_pack_factor;
2187
2188 input_layernorm
2189 + post_attention_layernorm
2190 + q_proj
2191 + k_proj
2192 + v_proj
2193 + o_proj
2194 + gate_proj
2195 + up_proj
2196 + down_proj
2197 };
2198 Ok(vec![
2199 per_layer_elems * dtype.size_in_bytes();
2200 cfg.num_hidden_layers
2201 ])
2202 }
2203
2204 fn num_layers(&self, config: &str) -> Result<usize> {
2205 let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2206 Ok(cfg.num_hidden_layers)
2207 }
2208 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2209 let cfg = Gemma2BasicConfig::deserialize(config, false)?;
2210
2211 let cfg = ModelConfigMetadata {
2212 max_seq_len: cfg.max_position_embeddings,
2213 num_layers: cfg.num_hidden_layers,
2214 hidden_size: cfg.hidden_size,
2215 num_kv_heads: cfg.num_key_value_heads,
2216 num_attn_heads: cfg.num_attention_heads,
2217 sliding_window: None,
2218 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2219 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2220 };
2221
2222 Ok(Box::new(cfg))
2223 }
2224}
2225
2226#[derive(Deserialize, Debug)]
2229struct Starcoder2BasicConfig {
2230 vocab_size: usize,
2231 hidden_size: usize,
2232 intermediate_size: usize,
2233 num_hidden_layers: usize,
2234 num_attention_heads: usize,
2235 num_key_value_heads: usize,
2236 hidden_act: Activation,
2237 max_position_embeddings: usize,
2238 norm_epsilon: f64,
2239 rope_theta: f64,
2240 use_bias: bool,
2241 sliding_window: Option<usize>,
2242 quantization_config: Option<QuantizedConfig>,
2243 #[serde(default = "word_emb_default")]
2244 tie_word_embeddings: bool,
2245}
2246
2247impl Starcoder2BasicConfig {
2248 fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::starcoder2::Config> {
2249 let basic_config: Self = serde_json::from_str(slice)?;
2250 Ok(models::starcoder2::Config {
2251 vocab_size: basic_config.vocab_size,
2252 hidden_size: basic_config.hidden_size,
2253 intermediate_size: basic_config.intermediate_size,
2254 num_hidden_layers: basic_config.num_hidden_layers,
2255 num_attention_heads: basic_config.num_attention_heads,
2256 num_key_value_heads: basic_config.num_key_value_heads,
2257 hidden_act: basic_config.hidden_act,
2258 max_position_embeddings: basic_config.max_position_embeddings,
2259 rope_theta: basic_config.rope_theta,
2260 sliding_window: basic_config.sliding_window,
2261 use_flash_attn,
2262 norm_epsilon: basic_config.norm_epsilon,
2263 use_bias: basic_config.use_bias,
2264 quantization_config: basic_config.quantization_config,
2265 tie_word_embeddings: basic_config.tie_word_embeddings,
2266 })
2267 }
2268}
2269
2270pub struct Starcoder2Loader;
2274
2275impl NormalModelLoader for Starcoder2Loader {
2276 fn load(
2277 &self,
2278 config: &str,
2279 use_flash_attn: bool,
2280 vb: ShardedVarBuilder,
2281 normal_loading_metadata: NormalLoadingMetadata,
2282 attention_mechanism: AttentionImplementation,
2283 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2284 Ok(Box::new(models::starcoder2::Model::new(
2285 &Starcoder2BasicConfig::deserialize(config, use_flash_attn)?,
2286 vb,
2287 self.is_gptx(config)?,
2288 normal_loading_metadata,
2289 attention_mechanism,
2290 )?))
2291 }
2292 fn load_xlora(
2293 &self,
2294 config: &str,
2295 use_flash_attn: bool,
2296 vb: ShardedVarBuilder,
2297 lora_config: &[((String, String), LoraConfig)],
2298 xlora_config: Option<XLoraConfig>,
2299 xlora_ordering: Ordering,
2300 normal_loading_metadata: NormalLoadingMetadata,
2301 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2302 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2303 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2304 &Starcoder2BasicConfig::deserialize(config, use_flash_attn)?,
2305 vb,
2306 lora_config,
2307 xlora_config,
2308 xlora_ordering,
2309 self.is_gptx(config)?,
2310 normal_loading_metadata,
2311 preload_adapters,
2312 )?))
2313 }
2314 fn is_gptx(&self, _: &str) -> Result<bool> {
2315 Ok(true)
2316 }
2317 fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2318 Ok(Box::new(serde_json::from_str::<Starcoder2BasicConfig>(
2319 config,
2320 )?))
2321 }
2322}
2323
2324impl IsqModelLoader for Starcoder2Loader {
2325 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2326 Ok(vec![
2327 Regex::new(r"lm_head\.(weight|bias)$")?,
2328 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2330 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2331 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2332 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
2333 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2335 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2336 ])
2337 }
2338}
2339
2340impl DeviceMappedModelLoader for Starcoder2Loader {
2341 fn mapped_max_act_size_elems(
2342 &self,
2343 config: &str,
2344 params: &AutoDeviceMapParams,
2345 prompt_chunksize: usize,
2346 ) -> Result<usize> {
2347 let AutoDeviceMapParams::Text {
2348 max_seq_len: _,
2349 max_batch_size,
2350 } = params
2351 else {
2352 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2353 };
2354
2355 let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2356
2357 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2358 }
2359 fn non_mapped_max_act_size_elems(
2360 &self,
2361 _config: &str,
2362 _params: &AutoDeviceMapParams,
2363 ) -> Result<usize> {
2364 Ok(0)
2365 }
2366
2367 fn non_mapped_size_in_bytes(
2368 &self,
2369 config: &str,
2370 dtype: DType,
2371 weight_pack_factor: usize,
2372 ) -> Result<usize> {
2373 let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2374 let elems = {
2375 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2376 let lm_head = if !cfg.tie_word_embeddings {
2377 cfg.hidden_size * cfg.vocab_size
2378 } else {
2379 0
2380 };
2381 let norm = cfg.hidden_size + cfg.hidden_size;
2382 embed_tokens + lm_head + norm
2383 };
2384 Ok(elems * dtype.size_in_bytes())
2385 }
2386
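// Starcoder2 uses LayerNorm (weight + bias), hence the `hidden_size + hidden_size` terms
// for each norm, and a plain two-layer MLP (`fc1`/`fc2`) rather than a gated
// gate/up/down block; `use_bias` gates every projection bias via `bias_if!`.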
2387 fn layer_sizes_in_bytes(
2388 &self,
2389 config: &str,
2390 dtype: DType,
2391 weight_pack_factor: usize,
2392 ) -> Result<Vec<usize>> {
2393 let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2394 let per_layer_elems = {
2395 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2396 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2397
2398 let size_in = cfg.hidden_size;
2399 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2400 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2401 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2402 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2403 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2404 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2405
2406 let h_size = cfg.hidden_size;
2407 let i_size = cfg.intermediate_size;
2408 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2409 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2410
2411 input_layernorm
2412 + post_attention_layernorm
2413 + q_proj
2414 + k_proj
2415 + v_proj
2416 + o_proj
2417 + fc1
2418 + fc2
2419 };
2420 Ok(vec![
2421 per_layer_elems * dtype.size_in_bytes();
2422 cfg.num_hidden_layers
2423 ])
2424 }
2425
2426 fn num_layers(&self, config: &str) -> Result<usize> {
2427 let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2428 Ok(cfg.num_hidden_layers)
2429 }
2430
2431 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2432 let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
2433
2434 let cfg = ModelConfigMetadata {
2435 max_seq_len: cfg.max_position_embeddings,
2436 num_layers: cfg.num_hidden_layers,
2437 hidden_size: cfg.hidden_size,
2438 num_kv_heads: cfg.num_key_value_heads,
2439 num_attn_heads: cfg.num_attention_heads,
2440 sliding_window: cfg.sliding_window,
2441 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2442 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2443 };
2444
2445 Ok(Box::new(cfg))
2446 }
2447}
2448
2449#[derive(Deserialize)]
2452struct Phi3_5MoEBasicConfig {
2453 vocab_size: usize,
2454 hidden_act: Activation,
2455 hidden_size: usize,
2456 intermediate_size: usize,
2457 num_hidden_layers: usize,
2458 num_attention_heads: usize,
2459 num_key_value_heads: usize,
2460 rms_norm_eps: f64,
2461 rope_theta: f64,
2462 rope_scaling: Option<PhiRopeScalingConfig>,
2463 max_position_embeddings: usize,
2464 original_max_position_embeddings: usize,
2465 sliding_window: Option<usize>,
2466 quantization_config: Option<QuantizedConfig>,
2467 lm_head_bias: bool,
2468 attention_bias: bool,
2469 num_local_experts: usize,
2470 router_jitter_noise: f64,
2471 #[serde(default = "word_emb_default")]
2472 tie_word_embeddings: bool,
2473}
2474
2475impl Phi3_5MoEBasicConfig {
2476 fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi3_5_moe::Config> {
2477 let basic_config: Self = serde_json::from_str(slice)?;
2478 Ok(models::phi3_5_moe::Config {
2479 vocab_size: basic_config.vocab_size,
2480 hidden_size: basic_config.hidden_size,
2481 intermediate_size: basic_config.intermediate_size,
2482 num_hidden_layers: basic_config.num_hidden_layers,
2483 num_attention_heads: basic_config.num_attention_heads,
2484 num_key_value_heads: basic_config.num_key_value_heads,
2485 hidden_act: basic_config.hidden_act,
2486 max_position_embeddings: basic_config.max_position_embeddings,
2487 rope_theta: basic_config.rope_theta,
2488 rms_norm_eps: basic_config.rms_norm_eps,
2489 rope_scaling: basic_config.rope_scaling,
2490 original_max_position_embeddings: basic_config.original_max_position_embeddings,
2491 use_flash_attn,
2492 sliding_window: basic_config.sliding_window,
2493 quantization_config: basic_config.quantization_config,
2494 lm_head_bias: basic_config.lm_head_bias,
2495 attention_bias: basic_config.attention_bias,
2496 num_local_experts: basic_config.num_local_experts,
2497 router_jitter_noise: basic_config.router_jitter_noise,
2498 tie_word_embeddings: basic_config.tie_word_embeddings,
2499 })
2500 }
2501}
2502
2503pub struct Phi3_5MoELoader;
2507
2508impl NormalModelLoader for Phi3_5MoELoader {
2509 fn load(
2510 &self,
2511 config: &str,
2512 use_flash_attn: bool,
2513 vb: ShardedVarBuilder,
2514 normal_loading_metadata: NormalLoadingMetadata,
2515 attention_mechanism: AttentionImplementation,
2516 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2517 Ok(Box::new(models::phi3_5_moe::Model::new(
2518 &Phi3_5MoEBasicConfig::deserialize(config, use_flash_attn)?,
2519 vb,
2520 self.is_gptx(config)?,
2521 normal_loading_metadata,
2522 attention_mechanism,
2523 )?))
2524 }
2525 fn load_xlora(
2526 &self,
2527 config: &str,
2528 use_flash_attn: bool,
2529 vb: ShardedVarBuilder,
2530 lora_config: &[((String, String), LoraConfig)],
2531 xlora_config: Option<XLoraConfig>,
2532 xlora_ordering: Ordering,
2533 normal_loading_metadata: NormalLoadingMetadata,
2534 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2535 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2536 Ok(Box::new(xlora_models::XLoraPhi3::new(
2537 &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
2538 vb,
2539 lora_config,
2540 xlora_config,
2541 xlora_ordering,
2542 self.is_gptx(config)?,
2543 normal_loading_metadata,
2544 preload_adapters,
2545 )?))
2546 }
2547 fn is_gptx(&self, _: &str) -> Result<bool> {
2548 Ok(true)
2549 }
2550 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2551 Ok(Box::new(Phi3_5MoEBasicConfig::deserialize(
2552 config,
2553 use_flash_attn,
2554 )?))
2555 }
2556}
2557
2558impl IsqModelLoader for Phi3_5MoELoader {
2559 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2560 Ok(vec![
2561 Regex::new(r"lm_head\.(weight|bias)$")?,
2562 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2564 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2565 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2566 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2567 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2569 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2570 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2571 ])
2572 }
2573
2574 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2575 Ok(vec![
2576 Regex::new(r"lm_head\.(weight|bias)$")?,
2577 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2579 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2580 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2581 ])
2582 }
2583}
2584
2585impl DeviceMappedModelLoader for Phi3_5MoELoader {
2586 fn mapped_max_act_size_elems(
2587 &self,
2588 config: &str,
2589 params: &AutoDeviceMapParams,
2590 prompt_chunksize: usize,
2591 ) -> Result<usize> {
2592 let AutoDeviceMapParams::Text {
2593 max_seq_len: _,
2594 max_batch_size,
2595 } = params
2596 else {
2597 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2598 };
2599
2600 let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2601
2602 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2603 }
2604 fn non_mapped_max_act_size_elems(
2605 &self,
2606 _config: &str,
2607 _params: &AutoDeviceMapParams,
2608 ) -> Result<usize> {
2609 Ok(0)
2610 }
2611
2612 fn non_mapped_size_in_bytes(
2613 &self,
2614 config: &str,
2615 dtype: DType,
2616 weight_pack_factor: usize,
2617 ) -> Result<usize> {
2618 let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2619 let elems = {
2620 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2621 let lm_head = if !cfg.tie_word_embeddings {
2622 cfg.hidden_size * cfg.vocab_size
2623 } else {
2624 0
2625 };
2626 let norm = cfg.hidden_size;
2627 embed_tokens + lm_head + norm
2628 };
2629 Ok(elems * dtype.size_in_bytes())
2630 }
2631
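// Each MoE block contributes the router gate (`hidden_size * num_local_experts`) plus
// `num_local_experts` copies of the per-expert w1/w2/w3 projections, all sized by
// `intermediate_size`.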
2632 fn layer_sizes_in_bytes(
2633 &self,
2634 config: &str,
2635 dtype: DType,
2636 weight_pack_factor: usize,
2637 ) -> Result<Vec<usize>> {
2638 let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2639 let per_layer_elems = {
2640 let input_layernorm = cfg.hidden_size;
2641 let post_attention_layernorm = cfg.hidden_size;
2642
2643 let size_in = cfg.hidden_size;
2644 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2645 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2646 let q_proj =
2647 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2648 let k_proj =
2649 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2650 let v_proj =
2651 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2652 let o_proj =
2653 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2654
2655 let moe_block = {
2656 let gate = cfg.hidden_size * cfg.num_local_experts;
2657 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2659 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2660 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2661 gate + cfg.num_local_experts * w1
2662 + cfg.num_local_experts * w2
2663 + cfg.num_local_experts * w3
2664 };
2665
2666 input_layernorm
2667 + post_attention_layernorm
2668 + q_proj
2669 + k_proj
2670 + v_proj
2671 + o_proj
2672 + moe_block
2673 };
2674 Ok(vec![
2675 per_layer_elems * dtype.size_in_bytes();
2676 cfg.num_hidden_layers
2677 ])
2678 }
2679
2680 fn num_layers(&self, config: &str) -> Result<usize> {
2681 let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2682 Ok(cfg.num_hidden_layers)
2683 }
2684
2685 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2686 let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
2687
2688 let cfg = ModelConfigMetadata {
2689 max_seq_len: cfg.max_position_embeddings,
2690 num_layers: cfg.num_hidden_layers,
2691 hidden_size: cfg.hidden_size,
2692 num_kv_heads: cfg.num_key_value_heads,
2693 num_attn_heads: cfg.num_attention_heads,
2694 sliding_window: cfg.sliding_window,
2695 k_head_dim: cfg.head_dim(),
2696 v_head_dim: cfg.head_dim(),
2697 };
2698
2699 Ok(Box::new(cfg))
2700 }
2701}
2702
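/// Loader for DeepSeek-V2 style text models (multi-head latent attention plus a mixture
/// of routed and shared experts). The X-LoRA path is currently unimplemented.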
2703pub struct DeepSeekV2Loader;
2707
2708impl NormalModelLoader for DeepSeekV2Loader {
2709 fn load(
2710 &self,
2711 config: &str,
2712 use_flash_attn: bool,
2713 vb: ShardedVarBuilder,
2714 normal_loading_metadata: NormalLoadingMetadata,
2715 attention_mechanism: AttentionImplementation,
2716 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2717 let mut cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2718 cfg.use_flash_attn = use_flash_attn;
2719 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2720 &cfg,
2721 vb,
2722 self.is_gptx(config)?,
2723 normal_loading_metadata,
2724 attention_mechanism,
2725 )?))
2726 }
2727 fn load_xlora(
2728 &self,
2729 _config: &str,
2730 _use_flash_attn: bool,
2731 _vb: ShardedVarBuilder,
2732 _lora_config: &[((String, String), LoraConfig)],
2733 _xlora_config: Option<XLoraConfig>,
2734 _xlora_ordering: Ordering,
2735 _normal_loading_metadata: NormalLoadingMetadata,
2736 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2737 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2738 todo!()
2739 }
2740 fn is_gptx(&self, _: &str) -> Result<bool> {
2741 Ok(true)
2742 }
2743 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2744 let mut config: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2745 config.use_flash_attn = use_flash_attn;
2746 Ok(Box::new(config))
2747 }
2748}
2749
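// The ISQ regex list is built dynamically from the config: the query projection regexes
// depend on whether `q_lora_rank` is set (low-rank `q_a_proj`/`q_b_proj` versus a single
// `q_proj`), and per-layer expert regexes are emitted for every MoE layer, i.e. layers at
// or beyond `first_k_dense_replace` whose index is a multiple of `moe_layer_freq`; dense
// layers fall back to plain gate/up/down patterns.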
2750impl IsqModelLoader for DeepSeekV2Loader {
2751 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2752 let mut data = vec![
2753 Regex::new(r"lm_head\.(weight|bias)$")?,
2754 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2756 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2757 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2758 ];
2759 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2760 if cfg.q_lora_rank.is_some() {
2761 data.extend(vec![
2762 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2763 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2764 ]);
2765 } else {
2766 data.push(Regex::new(
2767 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2768 )?);
2769 }
2770 for layer_idx in 0..cfg.num_hidden_layers {
2771 if cfg.n_routed_experts.is_some()
2772 && layer_idx >= cfg.first_k_dense_replace
2773 && layer_idx % cfg.moe_layer_freq == 0
2774 {
2775 for i in 0..cfg.n_routed_experts.unwrap() {
2776 data.extend(vec![
2777 Regex::new(&format!(
2778 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2779 ))?,
2780 Regex::new(&format!(
2781 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2782 ))?,
2783 Regex::new(&format!(
2784 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2785 ))?,
2786 ]);
2787 }
2788 if cfg.n_shared_experts.is_some() {
2789 data.extend(vec![
2790 Regex::new(&format!(
2791 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2792 ))?,
2793 Regex::new(&format!(
2794 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2795 ))?,
2796 Regex::new(&format!(
2797 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2798 ))?,
2799 ]);
2800 }
2801 } else {
2802 data.extend(vec![
2803 Regex::new(&format!(
2804 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2805 ))?,
2806 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2807 Regex::new(&format!(
2808 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2809 ))?,
2810 ]);
2811 };
2812 }
2813 Ok(data)
2814 }
2815
2816 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2817 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2818 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2819 for layer_idx in 0..cfg.num_hidden_layers {
2820 if cfg.n_routed_experts.is_some()
2821 && layer_idx >= cfg.first_k_dense_replace
2822 && layer_idx % cfg.moe_layer_freq == 0
2823 {
2824 for i in 0..cfg.n_routed_experts.unwrap() {
2825 data.extend(vec![
2826 Regex::new(&format!(
2827 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2828 ))?,
2829 Regex::new(&format!(
2830 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2831 ))?,
2832 Regex::new(&format!(
2833 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2834 ))?,
2835 ]);
2836 }
2837 if cfg.n_shared_experts.is_some() {
2838 data.extend(vec![
2839 Regex::new(&format!(
2840 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2841 ))?,
2842 Regex::new(&format!(
2843 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2844 ))?,
2845 Regex::new(&format!(
2846 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2847 ))?,
2848 ]);
2849 }
2850 } else {
2851 data.extend(vec![
2852 Regex::new(&format!(
2853 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2854 ))?,
2855 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2856 Regex::new(&format!(
2857 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2858 ))?,
2859 ]);
2860 };
2861 }
2862 Ok(data)
2863 }
2864}
2865
2866impl DeviceMappedModelLoader for DeepSeekV2Loader {
2867 fn mapped_max_act_size_elems(
2868 &self,
2869 config: &str,
2870 params: &AutoDeviceMapParams,
2871 prompt_chunksize: usize,
2872 ) -> Result<usize> {
2873 let AutoDeviceMapParams::Text {
2874 max_seq_len: _,
2875 max_batch_size,
2876 } = params
2877 else {
2878 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2879 };
2880
2881 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2882
2883 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2884 }
2885 fn non_mapped_max_act_size_elems(
2886 &self,
2887 _config: &str,
2888 _params: &AutoDeviceMapParams,
2889 ) -> Result<usize> {
2890 Ok(0)
2891 }
2892
2893 fn non_mapped_size_in_bytes(
2894 &self,
2895 config: &str,
2896 dtype: DType,
2897 weight_pack_factor: usize,
2898 ) -> Result<usize> {
2899 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2900 let elems = {
2901 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2902 let lm_head = if !cfg.tie_word_embeddings {
2903 cfg.hidden_size * cfg.vocab_size
2904 } else {
2905 0
2906 };
2907 let norm = cfg.hidden_size;
2908 embed_tokens + lm_head + norm
2909 };
2910 Ok(elems * dtype.size_in_bytes())
2911 }
2912
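// Layer sizes are returned as an explicit per-layer Vec (rather than `vec![x; n]`)
// because dense and MoE layers differ in size. Attention follows the MLA layout: the
// query path is either low-rank (`a + norm + b` when `q_lora_rank` is set) or a full
// projection, `kv_a_proj_with_mqa` compresses K/V into `kv_lora_rank + qk_rope_head_dim`
// dims, and `kv_b_proj` expands back out per head.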
2913 fn layer_sizes_in_bytes(
2914 &self,
2915 config: &str,
2916 dtype: DType,
2917 weight_pack_factor: usize,
2918 ) -> Result<Vec<usize>> {
2919 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2920 let mut per_layer_elems = Vec::new();
2921
2922 for layer_idx in 0..cfg.num_hidden_layers {
2923 let input_layernorm = cfg.hidden_size;
2924 let post_attention_layernorm = cfg.hidden_size;
2925
2926 let q_proj = match cfg.q_lora_rank {
2927 Some(lora_rank) => {
2928 let a = cfg.hidden_size * lora_rank;
2929 let norm = lora_rank;
2930 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2931 a + norm + b
2932 }
2933 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2934 };
2935 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2936 / weight_pack_factor
2937 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2938 let kv_a_layernorm = cfg.kv_lora_rank;
2939 let kv_b_proj = cfg.kv_lora_rank
2940 * cfg.num_attention_heads
2941 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2942 / weight_pack_factor;
2943 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2944 / weight_pack_factor
2945 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2946
2947 let moe_block = {
2948 let mut sum = 0;
2949 if cfg.n_routed_experts.is_some()
2950 && layer_idx >= cfg.first_k_dense_replace
2951 && layer_idx % cfg.moe_layer_freq == 0
2952 {
2953 let h_size = cfg.hidden_size;
2954 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2955 * cfg.n_routed_experts.unwrap();
2956 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2957 * cfg.n_routed_experts.unwrap();
2958 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2959 * cfg.n_routed_experts.unwrap();
2960 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2961 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2962 / weight_pack_factor;
2963 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2964 / weight_pack_factor;
2965 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2966 / weight_pack_factor;
2967 gate_proj + up_proj + down_proj
2968 } else {
2969 0
2970 };
2971 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2972 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2973 } else {
2974 let h_size = cfg.hidden_size;
2975 let i_size = cfg.intermediate_size;
2976 let gate_proj = h_size * i_size / weight_pack_factor;
2977 let up_proj = h_size * i_size / weight_pack_factor;
2978 let down_proj = i_size * h_size / weight_pack_factor;
2979 sum += gate_proj + up_proj + down_proj;
2980 }
2981 sum
2982 };
2983
2984 per_layer_elems.push(
2985 input_layernorm
2986 + post_attention_layernorm
2987 + q_proj
2988 + kv_a_layernorm
2989 + kv_a_proj_with_mqa
2990 + kv_b_proj
2991 + o_proj
2992 + moe_block,
2993 );
2994 }
2995
2996 Ok(per_layer_elems
2997 .into_iter()
2998 .map(|x| x * dtype.size_in_bytes())
2999 .collect())
3000 }
3001
3002 fn num_layers(&self, config: &str) -> Result<usize> {
3003 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
3004 Ok(cfg.num_hidden_layers)
3005 }
3006
3007 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3008 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
3009
3010 let cfg = ModelConfigMetadata {
3011 max_seq_len: cfg.max_position_embeddings,
3012 num_layers: cfg.num_hidden_layers,
3013 hidden_size: cfg.hidden_size,
3014 num_kv_heads: cfg.num_attention_heads,
3015 num_attn_heads: cfg.num_attention_heads,
3016 sliding_window: None,
3017 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3018 v_head_dim: cfg.v_head_dim,
3019 };
3020
3021 Ok(Box::new(cfg))
3022 }
3023}
3024
3025pub struct DeepSeekV3Loader;
3029
3030impl NormalModelLoader for DeepSeekV3Loader {
3031 fn load(
3032 &self,
3033 config: &str,
3034 use_flash_attn: bool,
3035 vb: ShardedVarBuilder,
3036 normal_loading_metadata: NormalLoadingMetadata,
3037 attention_mechanism: AttentionImplementation,
3038 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3039 let mut cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3040 cfg.use_flash_attn = use_flash_attn;
3041 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
3042 &cfg,
3043 vb,
3044 self.is_gptx(config)?,
3045 normal_loading_metadata,
3046 attention_mechanism,
3047 )?))
3048 }
3049 fn load_xlora(
3050 &self,
3051 _config: &str,
3052 _use_flash_attn: bool,
3053 _vb: ShardedVarBuilder,
3054 _lora_config: &[((String, String), LoraConfig)],
3055 _xlora_config: Option<XLoraConfig>,
3056 _xlora_ordering: Ordering,
3057 _normal_loading_metadata: NormalLoadingMetadata,
3058 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3059 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3060 todo!()
3061 }
3062 fn is_gptx(&self, _: &str) -> Result<bool> {
3063 Ok(true)
3064 }
3065 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
3066 let mut config: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3067 config.use_flash_attn = use_flash_attn;
3068 Ok(Box::new(config))
3069 }
3070}
3071
3072impl IsqModelLoader for DeepSeekV3Loader {
3073 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3074 let mut data = vec![
3075 Regex::new(r"lm_head\.(weight|bias)$")?,
3076 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3078 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3079 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3080 ];
3081 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3082 if cfg.q_lora_rank.is_some() {
3083 data.extend(vec![
3084 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3085 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3086 ]);
3087 } else {
3088 data.push(Regex::new(
3089 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
3090 )?);
3091 }
3092 for layer_idx in 0..cfg.num_hidden_layers {
3093 if cfg.n_routed_experts.is_some()
3094 && layer_idx >= cfg.first_k_dense_replace
3095 && layer_idx % cfg.moe_layer_freq == 0
3096 {
3097 for i in 0..cfg.n_routed_experts.unwrap() {
3098 data.extend(vec![
3099 Regex::new(&format!(
3100 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3101 ))?,
3102 Regex::new(&format!(
3103 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3104 ))?,
3105 Regex::new(&format!(
3106 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3107 ))?,
3108 ]);
3109 }
3110 if cfg.n_shared_experts.is_some() {
3111 data.extend(vec![
3112 Regex::new(&format!(
3113 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3114 ))?,
3115 Regex::new(&format!(
3116 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3117 ))?,
3118 Regex::new(&format!(
3119 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3120 ))?,
3121 ]);
3122 }
3123 } else {
3124 data.extend(vec![
3125 Regex::new(&format!(
3126 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3127 ))?,
3128 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
3129 Regex::new(&format!(
3130 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3131 ))?,
3132 ]);
3133 };
3134 }
3135 Ok(data)
3136 }
3137
3138 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3139 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3140 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3141 for layer_idx in 0..cfg.num_hidden_layers {
3142 if cfg.n_routed_experts.is_some()
3143 && layer_idx >= cfg.first_k_dense_replace
3144 && layer_idx % cfg.moe_layer_freq == 0
3145 {
3146 for i in 0..cfg.n_routed_experts.unwrap() {
3147 data.extend(vec![
3148 Regex::new(&format!(
3149 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3150 ))?,
3151 Regex::new(&format!(
3152 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3153 ))?,
3154 Regex::new(&format!(
3155 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3156 ))?,
3157 ]);
3158 }
3159 if cfg.n_shared_experts.is_some() {
3160 data.extend(vec![
3161 Regex::new(&format!(
3162 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3163 ))?,
3164 Regex::new(&format!(
3165 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3166 ))?,
3167 Regex::new(&format!(
3168 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3169 ))?,
3170 ]);
3171 }
3172 } else {
3173 data.extend(vec![
3174 Regex::new(&format!(
3175 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3176 ))?,
3177 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
3178 Regex::new(&format!(
3179 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3180 ))?,
3181 ]);
3182 };
3183 }
3184 Ok(data)
3185 }
3186}
3187
3188impl DeviceMappedModelLoader for DeepSeekV3Loader {
3189 fn mapped_max_act_size_elems(
3190 &self,
3191 config: &str,
3192 params: &AutoDeviceMapParams,
3193 prompt_chunksize: usize,
3194 ) -> Result<usize> {
3195 let AutoDeviceMapParams::Text {
3196 max_seq_len: _,
3197 max_batch_size,
3198 } = params
3199 else {
3200 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3201 };
3202
3203 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3204
3205 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3206 }
3207 fn non_mapped_max_act_size_elems(
3208 &self,
3209 _config: &str,
3210 _params: &AutoDeviceMapParams,
3211 ) -> Result<usize> {
3212 Ok(0)
3213 }
3214
3215 fn non_mapped_size_in_bytes(
3216 &self,
3217 config: &str,
3218 dtype: DType,
3219 weight_pack_factor: usize,
3220 ) -> Result<usize> {
3221 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3222 let elems = {
3223 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3224 let lm_head = if !cfg.tie_word_embeddings {
3225 cfg.hidden_size * cfg.vocab_size
3226 } else {
3227 0
3228 };
3229 let norm = cfg.hidden_size;
3230 embed_tokens + lm_head + norm
3231 };
3232 Ok(elems * dtype.size_in_bytes())
3233 }
3234
3235 fn layer_sizes_in_bytes(
3236 &self,
3237 config: &str,
3238 dtype: DType,
3239 weight_pack_factor: usize,
3240 ) -> Result<Vec<usize>> {
3241 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3242 let mut per_layer_elems = Vec::new();
3243
3244 for layer_idx in 0..cfg.num_hidden_layers {
3245 let input_layernorm = cfg.hidden_size;
3246 let post_attention_layernorm = cfg.hidden_size;
3247
3248 let q_proj = match cfg.q_lora_rank {
3249 Some(lora_rank) => {
3250 let a = cfg.hidden_size * lora_rank;
3251 let norm = lora_rank;
3252 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
3253 a + norm + b
3254 }
3255 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
3256 };
3257 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
3258 / weight_pack_factor
3259 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
3260 let kv_a_layernorm = cfg.kv_lora_rank;
3261 let kv_b_proj = cfg.kv_lora_rank
3262 * cfg.num_attention_heads
3263 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3264 / weight_pack_factor;
3265 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
3266 / weight_pack_factor
3267 + bias_if!(cfg.attention_bias, cfg.hidden_size);
3268
3269 let moe_block = {
3270 let mut sum = 0;
3271 if cfg.n_routed_experts.is_some()
3272 && layer_idx >= cfg.first_k_dense_replace
3273 && layer_idx % cfg.moe_layer_freq == 0
3274 {
3275 let h_size = cfg.hidden_size;
3276 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3277 * cfg.n_routed_experts.unwrap();
3278 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3279 * cfg.n_routed_experts.unwrap();
3280 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3281 * cfg.n_routed_experts.unwrap();
3282 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3283 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3284 / weight_pack_factor;
3285 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3286 / weight_pack_factor;
3287 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3288 / weight_pack_factor;
3289 gate_proj + up_proj + down_proj
3290 } else {
3291 0
3292 };
3293 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
3294 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3295 } else {
3296 let h_size = cfg.hidden_size;
3297 let i_size = cfg.intermediate_size;
3298 let gate_proj = h_size * i_size / weight_pack_factor;
3299 let up_proj = h_size * i_size / weight_pack_factor;
3300 let down_proj = i_size * h_size / weight_pack_factor;
3301 sum += gate_proj + up_proj + down_proj;
3302 }
3303 sum
3304 };
3305
3306 per_layer_elems.push(
3307 input_layernorm
3308 + post_attention_layernorm
3309 + q_proj
3310 + kv_a_layernorm
3311 + kv_a_proj_with_mqa
3312 + kv_b_proj
3313 + o_proj
3314 + moe_block,
3315 );
3316 }
3317
3318 Ok(per_layer_elems
3319 .into_iter()
3320 .map(|x| x * dtype.size_in_bytes())
3321 .collect())
3322 }
3323
3324 fn num_layers(&self, config: &str) -> Result<usize> {
3325 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3326 Ok(cfg.num_hidden_layers)
3327 }
3328
3329 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3330 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3331
3332 let cfg = ModelConfigMetadata {
3333 max_seq_len: cfg.max_position_embeddings,
3334 num_layers: cfg.num_hidden_layers,
3335 hidden_size: cfg.hidden_size,
3336 num_kv_heads: cfg.num_attention_heads,
3337 num_attn_heads: cfg.num_attention_heads,
3338 sliding_window: None,
3339 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3340 v_head_dim: cfg.v_head_dim,
3341 };
3342
3343 Ok(Box::new(cfg))
3344 }
3345}
3346
3347pub struct Qwen3Loader;
3351
3352impl NormalModelLoader for Qwen3Loader {
3353 fn load(
3354 &self,
3355 config: &str,
3356 use_flash_attn: bool,
3357 vb: ShardedVarBuilder,
3358 normal_loading_metadata: NormalLoadingMetadata,
3359 attention_mechanism: AttentionImplementation,
3360 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3361 let mut cfg: models::qwen3::Config = serde_json::from_str(config)?;
3362 cfg.use_flash_attn = use_flash_attn;
3363
3364 Ok(Box::new(models::qwen3::Model::new(
3365 &cfg,
3366 vb,
3367 self.is_gptx(config)?,
3368 normal_loading_metadata,
3369 attention_mechanism,
3370 )?))
3371 }
3372 fn load_xlora(
3373 &self,
3374 _config: &str,
3375 _use_flash_attn: bool,
3376 _vb: ShardedVarBuilder,
3377 _lora_config: &[((String, String), LoraConfig)],
3378 _xlora_config: Option<XLoraConfig>,
3379 _xlora_ordering: Ordering,
3380 _normal_loading_metadata: NormalLoadingMetadata,
3381 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3382 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3383 todo!()
3384 }
3385 fn is_gptx(&self, _: &str) -> Result<bool> {
3386 Ok(true)
3387 }
3388 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
3389 let mut cfg: models::qwen3::Config = serde_json::from_str(config)?;
3390 cfg.use_flash_attn = use_flash_attn;
3391
3392 Ok(Box::new(cfg))
3393 }
3394}
3395
3396impl IsqModelLoader for Qwen3Loader {
3397 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3398 Ok(vec![
3399 Regex::new(r"lm_head\.(weight|bias)$")?,
3400 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3402 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3403 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3404 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
3405 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3407 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3408 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3409 ])
3410 }
3411}
3412
3413impl DeviceMappedModelLoader for Qwen3Loader {
3414 fn mapped_max_act_size_elems(
3415 &self,
3416 config: &str,
3417 params: &AutoDeviceMapParams,
3418 prompt_chunksize: usize,
3419 ) -> Result<usize> {
3420 let AutoDeviceMapParams::Text {
3421 max_seq_len: _,
3422 max_batch_size,
3423 } = params
3424 else {
3425 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3426 };
3427
3428 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3429
3430 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3431 }
3432 fn non_mapped_max_act_size_elems(
3433 &self,
3434 _config: &str,
3435 _params: &AutoDeviceMapParams,
3436 ) -> Result<usize> {
3437 Ok(0)
3438 }
3439
3440 fn non_mapped_size_in_bytes(
3441 &self,
3442 config: &str,
3443 dtype: DType,
3444 weight_pack_factor: usize,
3445 ) -> Result<usize> {
3446 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3447 let elems = {
3448 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3449 let lm_head = if !cfg.tie_word_embeddings {
3450 cfg.hidden_size * cfg.vocab_size
3451 } else {
3452 0
3453 };
3454 let norm = cfg.hidden_size;
3455 embed_tokens + lm_head + norm
3456 };
3457 Ok(elems * dtype.size_in_bytes())
3458 }
3459
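// Same accounting as Qwen2, plus the q/k RMSNorms that Qwen 3 applies over each head's
// features; each contributes `head_dim()` elements per layer.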
3460 fn layer_sizes_in_bytes(
3461 &self,
3462 config: &str,
3463 dtype: DType,
3464 weight_pack_factor: usize,
3465 ) -> Result<Vec<usize>> {
3466 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3467 let per_layer_elems = {
3468 let input_layernorm = cfg.hidden_size;
3469 let post_attention_layernorm = cfg.hidden_size;
3470
3471 let size_in = cfg.hidden_size;
3472 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3473 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3474 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3475 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3476 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3477 let o_proj = size_q * size_in / weight_pack_factor;
3478
3479 let h_size = cfg.hidden_size;
3480 let i_size = cfg.intermediate_size;
3481 let gate_proj = h_size * i_size / weight_pack_factor;
3482 let up_proj = h_size * i_size / weight_pack_factor;
3483 let down_proj = i_size * h_size / weight_pack_factor;
3484
3485 let q_norm = cfg.head_dim();
3486 let k_norm = cfg.head_dim();
3487
3488 input_layernorm
3489 + post_attention_layernorm
3490 + q_proj
3491 + k_proj
3492 + v_proj
3493 + o_proj
3494 + gate_proj
3495 + up_proj
3496 + down_proj
3497 + q_norm
3498 + k_norm
3499 };
3500 Ok(vec![
3501 per_layer_elems * dtype.size_in_bytes();
3502 cfg.num_hidden_layers
3503 ])
3504 }
3505
3506 fn num_layers(&self, config: &str) -> Result<usize> {
3507 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3508 Ok(cfg.num_hidden_layers)
3509 }
3510
3511 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3512 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3513
3514 let cfg = ModelConfigMetadata {
3515 max_seq_len: cfg.max_position_embeddings,
3516 num_layers: cfg.num_hidden_layers,
3517 hidden_size: cfg.hidden_size,
3518 num_kv_heads: cfg.num_key_value_heads,
3519 num_attn_heads: cfg.num_attention_heads,
3520 sliding_window: cfg.sliding_window,
3521 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3522 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3523 };
3524
3525 Ok(Box::new(cfg))
3526 }
3527}
3528
3529pub struct Qwen3MoELoader;
3533
3534impl NormalModelLoader for Qwen3MoELoader {
3535 fn load(
3536 &self,
3537 config: &str,
3538 use_flash_attn: bool,
3539 vb: ShardedVarBuilder,
3540 normal_loading_metadata: NormalLoadingMetadata,
3541 attention_mechanism: AttentionImplementation,
3542 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3543 let mut cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3544 cfg.use_flash_attn = use_flash_attn;
3545
3546 Ok(Box::new(models::qwen3_moe::Model::new(
3547 &cfg,
3548 vb,
3549 self.is_gptx(config)?,
3550 normal_loading_metadata,
3551 attention_mechanism,
3552 )?))
3553 }
3554 fn load_xlora(
3555 &self,
3556 _config: &str,
3557 _use_flash_attn: bool,
3558 _vb: ShardedVarBuilder,
3559 _lora_config: &[((String, String), LoraConfig)],
3560 _xlora_config: Option<XLoraConfig>,
3561 _xlora_ordering: Ordering,
3562 _normal_loading_metadata: NormalLoadingMetadata,
3563 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3564 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3565 todo!()
3566 }
3567 fn is_gptx(&self, _: &str) -> Result<bool> {
3568 Ok(true)
3569 }
3570 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
3571 let mut cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3572 cfg.use_flash_attn = use_flash_attn;
3573
3574 Ok(Box::new(cfg))
3575 }
3576}
3577
3578impl IsqModelLoader for Qwen3MoELoader {
3579 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3580 Ok(vec![
3581 Regex::new(r"lm_head\.(weight|bias)$")?,
3582 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3584 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3585 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3586 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
3587 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3589 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3590 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3591 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3593 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3594 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3595 ])
3596 }
3597}

impl DeviceMappedModelLoader for Qwen3MoELoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;

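        // Peak device-mapped activation: one `prompt_chunksize x prompt_chunksize`
        // attention-score matrix per attention head, per sequence in the batch.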
        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
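        // Weights that live outside the mapped decoder layers: the token embedding,
        // the LM head (only when embeddings are untied), and the final norm.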
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
        let mut layer_sizes_in_bytes = Vec::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

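            // Projection weights are divided by `weight_pack_factor` to account for
            // quantized weight packing; biases and norms are counted unpacked.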
            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

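            // Layers on the sparse schedule (and not forced dense via `mlp_only_layers`)
            // hold `num_experts` expert MLPs; every other layer holds a single dense MLP.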
            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
            {
                let expert_size = {
                    let h_size = cfg.hidden_size;
                    let i_size = cfg.moe_intermediate_size;
                    let gate_proj = h_size * i_size / weight_pack_factor;
                    let up_proj = h_size * i_size / weight_pack_factor;
                    let down_proj = i_size * h_size / weight_pack_factor;
                    gate_proj + up_proj + down_proj
                };
                expert_size * cfg.num_experts
            } else {
                let h_size = cfg.hidden_size;
                let i_size = cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;
                gate_proj + up_proj + down_proj
            };

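            // Qwen 3 applies RMSNorm to the per-head query/key vectors, so each of the
            // two norms contributes `head_dim` elements per layer.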
            let q_norm = cfg.head_dim();
            let k_norm = cfg.head_dim();

            let size_elems = input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + mlp_size
                + q_norm
                + k_norm;

            let size_in_bytes = size_elems * dtype.size_in_bytes();
            layer_sizes_in_bytes.push(size_in_bytes);
        }

        Ok(layer_sizes_in_bytes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            // Qwen 3 MoE configs also carry an explicit head_dim (used for q_norm and
            // k_norm above) that can differ from hidden_size / num_attention_heads.
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}