use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    lora::{LoraConfig, Ordering},
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{
        isq::IsqModelLoader,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::varbuilder_utils::DeviceForLoadTensor,
    xlora_models::NonGranularState,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use mistralrs_quant::log::once_log_info;

use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use crate::{
    models,
    xlora_models::{self, XLoraConfig},
};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

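/// Trait implemented by every "normal" (text decoder) model in this crate: a standard
/// forward pass, an X-LoRA forward pass, and accessors for the KV cache, device,
/// maximum sequence length, and config metadata used by the pipelines.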
pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn is_xlora(&self) -> bool;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
}

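/// Metadata handed to a model constructor while it is being loaded.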
pub struct NormalLoadingMetadata {
    // Device mapper used to place tensors on their target devices during loading.
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Whether the model is being loaded for in-situ quantization (ISQ).
    pub loading_isq: bool,
    // The device the loaded model should ultimately run on.
    pub real_device: Device,
    // Shared progress-bar handle for parallelized loading.
    pub multi_progress: Arc<MultiProgress>,
    // Optional MatFormer slicing configuration.
    pub matformer_slicing_config: Option<MatformerSliceConfig>,
}

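/// A loader for a text ("normal") model. Implementors parse the raw `config.json` string
/// and construct the model (optionally with X-LoRA adapters) from a `ShardedVarBuilder`,
/// alongside the ISQ and device-mapping hooks the pipeline needs.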
pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    #[allow(clippy::too_many_arguments)]
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

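/// The architecture to load. The serde/`FromStr` names below (for example `"mistral"`,
/// `"qwen3moe"`, `"gpt_oss"`) are the string identifiers accepted when parsing or
/// (de)serializing this type; for instance, `"qwen3".parse::<NormalLoaderType>()` yields
/// `NormalLoaderType::Qwen3` via the `FromStr` impl further down.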
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    #[serde(rename = "qwen3")]
    Qwen3,
    #[serde(rename = "glm4")]
    GLM4,
    #[serde(rename = "glm4moelite")]
    GLM4MoeLite,
    #[serde(rename = "glm4moe")]
    GLM4Moe,
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
    #[serde(rename = "smollm3")]
    SmolLm3,
    #[serde(rename = "granitemoehybrid")]
    GraniteMoeHybrid,
    #[serde(rename = "gpt_oss")]
    GptOss,
}

impl NormalLoaderType {
    /// Map a Hugging Face `architectures` class name (e.g. `MistralForCausalLM`) to a loader type.
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "MistralForCausalLM" => Ok(Self::Mistral),
            "MixtralForCausalLM" => Ok(Self::Mixtral),
            "GemmaForCausalLM" => Ok(Self::Gemma),
            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
            "PhiForCausalLM" => Ok(Self::Phi2),
            "Phi3ForCausalLM" => Ok(Self::Phi3),
            "LlamaForCausalLM" => Ok(Self::Llama),
            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
            "Glm4ForCausalLM" => Ok(Self::GLM4),
            "Glm4MoeLiteForCausalLM" => Ok(Self::GLM4MoeLite),
            "Glm4MoeForCausalLM" => Ok(Self::GLM4Moe),
            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
            "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
            "GptOssForCausalLM" => Ok(Self::GptOss),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `-CausalLM` model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for NormalLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "mistral" => Ok(Self::Mistral),
            "gemma" => Ok(Self::Gemma),
            "mixtral" => Ok(Self::Mixtral),
            "llama" => Ok(Self::Llama),
            "phi2" => Ok(Self::Phi2),
            "phi3" => Ok(Self::Phi3),
            "qwen2" => Ok(Self::Qwen2),
            "gemma2" => Ok(Self::Gemma2),
            "starcoder2" => Ok(Self::Starcoder2),
            "phi3.5moe" => Ok(Self::Phi3_5MoE),
            "deepseekv2" => Ok(Self::DeepSeekV2),
            "deepseekv3" => Ok(Self::DeepSeekV3),
            "qwen3" => Ok(Self::Qwen3),
            "glm4" => Ok(Self::GLM4),
            "glm4moelite" => Ok(Self::GLM4MoeLite),
            "glm4moe" => Ok(Self::GLM4Moe),
            "qwen3moe" => Ok(Self::Qwen3Moe),
            "smollm3" => Ok(Self::SmolLm3),
            "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
            "gpt_oss" => Ok(Self::GptOss),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `glm4moelite`, `glm4moe`, `qwen3moe`, `smollm3`, `granitemoehybrid`, `gpt_oss`.")),
        }
    }
}

impl Display for NormalLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gemma => write!(f, "gemma"),
            Self::Gemma2 => write!(f, "gemma2"),
            Self::Llama => write!(f, "llama"),
            Self::Mistral => write!(f, "mistral"),
            Self::Mixtral => write!(f, "mixtral"),
            Self::Phi2 => write!(f, "phi2"),
            Self::Phi3 => write!(f, "phi3"),
            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
            Self::Qwen2 => write!(f, "qwen2"),
            Self::Starcoder2 => write!(f, "starcoder2"),
            Self::DeepSeekV2 => write!(f, "deepseekv2"),
            Self::DeepSeekV3 => write!(f, "deepseekv3"),
            Self::Qwen3 => write!(f, "qwen3"),
            Self::GLM4 => write!(f, "glm4"),
            Self::GLM4MoeLite => write!(f, "glm4moelite"),
            Self::GLM4Moe => write!(f, "glm4moe"),
            Self::Qwen3Moe => write!(f, "qwen3moe"),
            Self::SmolLm3 => write!(f, "smollm3"),
            Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
            Self::GptOss => write!(f, "gpt_oss"),
        }
    }
}

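/// Helper for the size accounting below: expands to `$size` when `$cond` is true and to
/// `0` otherwise, so optional bias terms only count when the config enables them.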
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

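/// Loader that picks the concrete `*Loader` automatically from the single entry of the
/// `architectures` field in the model's `config.json`, then delegates every call to it.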
pub struct AutoNormalLoader;

#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    architectures: Vec<String>,
}

impl AutoNormalLoader {
    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one name in the `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = NormalLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
            NormalLoaderType::GLM4MoeLite => Ok(Box::new(GLM4MoeLiteLoader)),
            NormalLoaderType::GLM4Moe => Ok(Box::new(GLM4MoeLoader)),
            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
            NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
            NormalLoaderType::GptOss => Ok(Box::new(GptOssLoader)),
        }
    }
}

impl NormalModelLoader for AutoNormalLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load_xlora(
            config,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            normal_loading_metadata,
            preload_adapters,
        )
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }
    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.supports_paged_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoNormalLoader {
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
    }
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for AutoNormalLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

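/// Loader for Mistral-architecture models. The `DeviceMappedModelLoader` impl below
/// estimates per-layer and non-mapped weight sizes (element counts scaled by the dtype
/// size) from the parsed config so that automatic device mapping can budget memory.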
pub struct MistralLoader;

impl NormalModelLoader for MistralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(models::mistral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(xlora_models::XLoraMistral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MistralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MistralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

pub struct GemmaLoader;

impl NormalModelLoader for GemmaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::gemma::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraGemma::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for GemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }
}

pub struct LlamaLoader;

impl NormalModelLoader for LlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::llama::Llama::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraLlama::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for LlamaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for LlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

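/// Loader for Mixtral (sparse MoE) models. Its `layer_sizes_in_bytes` accounts for the
/// router gate plus all `num_local_experts` copies of the expert weights (`w1`, `w2`, `w3`)
/// in every layer.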
pub struct MixtralLoader;

impl NormalModelLoader for MixtralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::mixtral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraMixtral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MixtralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MixtralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let moe_block = {
                let gate = cfg.hidden_size * cfg.num_local_experts;
                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                gate + cfg.num_local_experts * w1
                    + cfg.num_local_experts * w2
                    + cfg.num_local_experts * w3
            };

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + moe_block
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

pub struct Phi2Loader;

impl NormalModelLoader for Phi2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::phi2::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraPhi2::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Phi2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor + size_in;
            let (q_norm, k_norm) = if cfg.qk_layernorm {
                (cfg.head_dim(), cfg.head_dim())
            } else {
                (0, 0)
            };

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let fc1 = h_size * i_size / weight_pack_factor;
            let fc2 = h_size * i_size / weight_pack_factor;

            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

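/// Loader for Phi-3 models. Note the fused `qkv_proj` attention projection and the
/// `gate_up_proj` term in the per-layer size accounting below.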
pub struct Phi3Loader;

impl NormalModelLoader for Phi3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::phi3::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraPhi3::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Phi3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj =
                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

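/// Loader for Qwen2 models. X-LoRA loading is not implemented for this architecture:
/// `load_xlora` is a `todo!()`.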
pub struct Qwen2Loader;

impl NormalModelLoader for Qwen2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::qwen2::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Qwen2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

1795pub struct Gemma2Loader;
1801
1802impl NormalModelLoader for Gemma2Loader {
1803 fn load(
1804 &self,
1805 config: &str,
1806 vb: ShardedVarBuilder,
1807 normal_loading_metadata: NormalLoadingMetadata,
1808 attention_mechanism: AttentionImplementation,
1809 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1810 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1811
1812 Ok(Box::new(models::gemma2::Model::new(
1813 &cfg,
1814 vb,
1815 self.is_gptx(config)?,
1816 normal_loading_metadata,
1817 attention_mechanism,
1818 )?))
1819 }
1820 fn load_xlora(
1821 &self,
1822 config: &str,
1823 vb: ShardedVarBuilder,
1824 lora_config: &[((String, String), LoraConfig)],
1825 xlora_config: Option<XLoraConfig>,
1826 xlora_ordering: Ordering,
1827 normal_loading_metadata: NormalLoadingMetadata,
1828 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1829 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1830 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1831
1832 Ok(Box::new(xlora_models::XLoraGemma2::new(
1833 &cfg,
1834 vb,
1835 lora_config,
1836 xlora_config,
1837 xlora_ordering,
1838 self.is_gptx(config)?,
1839 normal_loading_metadata,
1840 preload_adapters,
1841 )?))
1842 }
1843 fn is_gptx(&self, _: &str) -> Result<bool> {
1844 Ok(true)
1845 }
1846 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1847 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1848
1849 Ok(Box::new(cfg))
1850 }
1851}
1852
1853impl IsqModelLoader for Gemma2Loader {
1854 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1855 Ok(vec![
1856 Regex::new(r"lm_head\.(weight|bias)$")?,
1857 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1859 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1860 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1861 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1862 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1864 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1865 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1866 ])
1867 }
1868 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1869 self.isq_layer_regexes(config)
1870 }
1871}
1872
1873impl DeviceMappedModelLoader for Gemma2Loader {
1874 fn mapped_max_act_size_elems(
1875 &self,
1876 config: &str,
1877 params: &AutoDeviceMapParams,
1878 ) -> Result<usize> {
1879 let AutoDeviceMapParams::Text {
1880 max_seq_len,
1881 max_batch_size,
1882 } = params
1883 else {
1884 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1885 };
1886
1887 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1888
1889 Ok(
1890 max_batch_size
1891 * cfg.num_attention_heads
1892 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1893 )
1894 }
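// Editorial note: this estimates the attention-score activation, i.e. one
// (len x len) matrix per head per sequence, with `len` clamped to
// ATTENTION_CHUNK_SIZE (attention is evaluated in chunks, so the score matrix
// never spans more than that many query tokens). For hypothetical values
// max_batch_size = 2, num_attention_heads = 16 and an effective length of 1024,
// this is 2 * 16 * 1024 * 1024 = 33_554_432 elements.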
1895 fn non_mapped_max_act_size_elems(
1896 &self,
1897 _config: &str,
1898 _params: &AutoDeviceMapParams,
1899 ) -> Result<usize> {
1900 Ok(0)
1901 }
1902
1903 fn non_mapped_size_in_bytes(
1904 &self,
1905 config: &str,
1906 dtype: DType,
1907 weight_pack_factor: usize,
1908 _matformer_config: Option<&MatformerSliceConfig>,
1909 ) -> Result<usize> {
1910 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1911
1912 let elems = {
1913 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1914 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1916 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1917 } else {
1918 0
1919 };
1920 let norm = cfg.hidden_size;
1921 embed_tokens + lm_head + norm
1922 };
1923 Ok(elems * dtype.size_in_bytes())
1924 }
1925
1926 fn layer_sizes_in_bytes(
1927 &self,
1928 config: &str,
1929 dtype: DType,
1930 weight_pack_factor: usize,
1931 _matformer_config: Option<&MatformerSliceConfig>,
1932 ) -> Result<Vec<usize>> {
1933 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1934
1935 let per_layer_elems = {
1936 let input_layernorm = cfg.hidden_size;
1937 let post_attention_layernorm = cfg.hidden_size;
1938
1939 let size_in = cfg.hidden_size;
1940 let size_q = cfg.head_dim * cfg.num_attention_heads;
1941 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1942 let q_proj =
1943 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1944 let k_proj =
1945 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1946 let v_proj =
1947 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1948 let o_proj =
1949 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1950
1951 let h_size = cfg.hidden_size;
1952 let i_size = cfg.intermediate_size;
1953 let gate_proj = h_size * i_size / weight_pack_factor;
1954 let up_proj = h_size * i_size / weight_pack_factor;
1955 let down_proj = i_size * h_size / weight_pack_factor;
1956
1957 input_layernorm
1958 + post_attention_layernorm
1959 + q_proj
1960 + k_proj
1961 + v_proj
1962 + o_proj
1963 + gate_proj
1964 + up_proj
1965 + down_proj
1966 };
1967 Ok(vec![
1968 per_layer_elems * dtype.size_in_bytes();
1969 cfg.num_hidden_layers
1970 ])
1971 }
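// Worked example (hypothetical Gemma2-style numbers, not taken from a real
// config): with hidden_size = 2304, head_dim = 256, num_attention_heads = 8,
// num_key_value_heads = 4, intermediate_size = 9216, attention_bias = false and
// weight_pack_factor = 1, the attention projections contribute
// 2304*2048 + 2*2304*1024 + 2048*2304 = 14_155_776 elements and the MLP
// 3 * 2304 * 9216 = 63_700_992, so one layer is roughly 78M elements, i.e.
// about 156 MB at 2 bytes per element, before the two (tiny) layernorm vectors.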
1972
1973 fn num_layers(&self, config: &str) -> Result<usize> {
1974 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1975
1976 Ok(cfg.num_hidden_layers)
1977 }
1978 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1979 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1980
1981 let cfg = ModelConfigMetadata {
1982 max_seq_len: cfg.max_position_embeddings,
1983 num_layers: cfg.num_hidden_layers,
1984 hidden_size: cfg.hidden_size,
1985 num_kv_heads: cfg.num_key_value_heads,
1986 num_attn_heads: cfg.num_attention_heads,
1987 sliding_window: None,
1988 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1989 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1990 };
1991
1992 Ok(Box::new(cfg))
1993 }
1994}
1995
1996pub struct Starcoder2Loader;
2002
2003impl NormalModelLoader for Starcoder2Loader {
2004 fn load(
2005 &self,
2006 config: &str,
2007 vb: ShardedVarBuilder,
2008 normal_loading_metadata: NormalLoadingMetadata,
2009 attention_mechanism: AttentionImplementation,
2010 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2011 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2012
2013 Ok(Box::new(models::starcoder2::Model::new(
2014 &cfg,
2015 vb,
2016 self.is_gptx(config)?,
2017 normal_loading_metadata,
2018 attention_mechanism,
2019 )?))
2020 }
2021 fn load_xlora(
2022 &self,
2023 config: &str,
2024 vb: ShardedVarBuilder,
2025 lora_config: &[((String, String), LoraConfig)],
2026 xlora_config: Option<XLoraConfig>,
2027 xlora_ordering: Ordering,
2028 normal_loading_metadata: NormalLoadingMetadata,
2029 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2030 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2031 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2032
2033 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2034 &cfg,
2035 vb,
2036 lora_config,
2037 xlora_config,
2038 xlora_ordering,
2039 self.is_gptx(config)?,
2040 normal_loading_metadata,
2041 preload_adapters,
2042 )?))
2043 }
2044 fn is_gptx(&self, _: &str) -> Result<bool> {
2045 Ok(true)
2046 }
2047 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2048 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2049
2050 Ok(Box::new(cfg))
2051 }
2052}
2053
2054impl IsqModelLoader for Starcoder2Loader {
2055 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2056 Ok(vec![
2057 Regex::new(r"lm_head\.(weight|bias)$")?,
2058 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2060 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2061 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2062 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2063 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2065 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2066 ])
2067 }
2068 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2069 self.isq_layer_regexes(config)
2070 }
2071}
2072
2073impl DeviceMappedModelLoader for Starcoder2Loader {
2074 fn mapped_max_act_size_elems(
2075 &self,
2076 config: &str,
2077 params: &AutoDeviceMapParams,
2078 ) -> Result<usize> {
2079 let AutoDeviceMapParams::Text {
2080 max_seq_len,
2081 max_batch_size,
2082 } = params
2083 else {
2084 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2085 };
2086
2087 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2088
2089 Ok(
2090 max_batch_size
2091 * cfg.num_attention_heads
2092 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2093 )
2094 }
2095 fn non_mapped_max_act_size_elems(
2096 &self,
2097 _config: &str,
2098 _params: &AutoDeviceMapParams,
2099 ) -> Result<usize> {
2100 Ok(0)
2101 }
2102
2103 fn non_mapped_size_in_bytes(
2104 &self,
2105 config: &str,
2106 dtype: DType,
2107 weight_pack_factor: usize,
2108 _matformer_config: Option<&MatformerSliceConfig>,
2109 ) -> Result<usize> {
2110 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2111
2112 let elems = {
2113 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2114 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2116 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2117 } else {
2118 0
2119 };
2120 let norm = cfg.hidden_size + cfg.hidden_size;
2121 embed_tokens + lm_head + norm
2122 };
2123 Ok(elems * dtype.size_in_bytes())
2124 }
2125
2126 fn layer_sizes_in_bytes(
2127 &self,
2128 config: &str,
2129 dtype: DType,
2130 weight_pack_factor: usize,
2131 _matformer_config: Option<&MatformerSliceConfig>,
2132 ) -> Result<Vec<usize>> {
2133 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2134
2135 let per_layer_elems = {
2136 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2137 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2138
2139 let size_in = cfg.hidden_size;
2140 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2141 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2142 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2143 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2144 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2145 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2146
2147 let h_size = cfg.hidden_size;
2148 let i_size = cfg.intermediate_size;
2149 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2150 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2151
2152 input_layernorm
2153 + post_attention_layernorm
2154 + q_proj
2155 + k_proj
2156 + v_proj
2157 + o_proj
2158 + fc1
2159 + fc2
2160 };
2161 Ok(vec![
2162 per_layer_elems * dtype.size_in_bytes();
2163 cfg.num_hidden_layers
2164 ])
2165 }
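// Editorial note: Starcoder2 uses LayerNorm rather than RMSNorm, so each norm
// carries a bias vector in addition to its weight; that is why the norm terms
// above (and the final `norm` in `non_mapped_size_in_bytes`) are counted as
// `hidden_size + hidden_size` instead of a single `hidden_size`.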
2166
2167 fn num_layers(&self, config: &str) -> Result<usize> {
2168 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2169
2170 Ok(cfg.num_hidden_layers)
2171 }
2172
2173 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2174 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2175
2176 let cfg = ModelConfigMetadata {
2177 max_seq_len: cfg.max_position_embeddings,
2178 num_layers: cfg.num_hidden_layers,
2179 hidden_size: cfg.hidden_size,
2180 num_kv_heads: cfg.num_key_value_heads,
2181 num_attn_heads: cfg.num_attention_heads,
2182 sliding_window: cfg.sliding_window,
2183 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2184 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2185 };
2186
2187 Ok(Box::new(cfg))
2188 }
2189}
2190
2191pub struct Phi3_5MoELoader;
2197
2198impl NormalModelLoader for Phi3_5MoELoader {
2199 fn load(
2200 &self,
2201 config: &str,
2202 vb: ShardedVarBuilder,
2203 normal_loading_metadata: NormalLoadingMetadata,
2204 attention_mechanism: AttentionImplementation,
2205 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2206 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2207
2208 Ok(Box::new(models::phi3_5_moe::Model::new(
2209 &cfg,
2210 vb,
2211 self.is_gptx(config)?,
2212 normal_loading_metadata,
2213 attention_mechanism,
2214 )?))
2215 }
2216 fn load_xlora(
2217 &self,
2218 config: &str,
2219 vb: ShardedVarBuilder,
2220 lora_config: &[((String, String), LoraConfig)],
2221 xlora_config: Option<XLoraConfig>,
2222 xlora_ordering: Ordering,
2223 normal_loading_metadata: NormalLoadingMetadata,
2224 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2225 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2226 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2227
2228 Ok(Box::new(xlora_models::XLoraPhi3::new(
2229 &cfg,
2230 vb,
2231 lora_config,
2232 xlora_config,
2233 xlora_ordering,
2234 self.is_gptx(config)?,
2235 normal_loading_metadata,
2236 preload_adapters,
2237 )?))
2238 }
2239 fn is_gptx(&self, _: &str) -> Result<bool> {
2240 Ok(true)
2241 }
2242 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2243 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2244
2245 Ok(Box::new(cfg))
2246 }
2247}
2248
2249impl IsqModelLoader for Phi3_5MoELoader {
2250 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2251 Ok(vec![
2252 Regex::new(r"lm_head\.(weight|bias)$")?,
2253 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2255 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2256 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2257 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2258 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2260 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2261 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2262 ])
2263 }
2264 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2265 self.isq_layer_regexes(config)
2266 }
2267
2268 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2269 Ok(vec![
2270 Regex::new(r"lm_head\.(weight|bias)$")?,
2271 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2273 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2274 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2275 ])
2276 }
2277 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2278 self.isq_layer_regexes_moqe(config)
2279 }
2280}
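// Editorial note: the `_moqe` variants restrict ISQ to `lm_head` and the expert
// FFN weights (`w1`/`w2`/`w3` under `block_sparse_moe.experts.*`), leaving the
// attention projections in full precision. For example, a tensor named
// `model.layers.5.block_sparse_moe.experts.3.w2.weight` (hypothetical path)
// matches, while `model.layers.5.self_attn.q_proj.weight` does not.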
2281
2282impl DeviceMappedModelLoader for Phi3_5MoELoader {
2283 fn mapped_max_act_size_elems(
2284 &self,
2285 config: &str,
2286 params: &AutoDeviceMapParams,
2287 ) -> Result<usize> {
2288 let AutoDeviceMapParams::Text {
2289 max_seq_len,
2290 max_batch_size,
2291 } = params
2292 else {
2293 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2294 };
2295
2296 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2297
2298 Ok(
2299 max_batch_size
2300 * cfg.num_attention_heads
2301 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2302 )
2303 }
2304 fn non_mapped_max_act_size_elems(
2305 &self,
2306 _config: &str,
2307 _params: &AutoDeviceMapParams,
2308 ) -> Result<usize> {
2309 Ok(0)
2310 }
2311
2312 fn non_mapped_size_in_bytes(
2313 &self,
2314 config: &str,
2315 dtype: DType,
2316 weight_pack_factor: usize,
2317 _matformer_config: Option<&MatformerSliceConfig>,
2318 ) -> Result<usize> {
2319 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2320
2321 let elems = {
2322 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2323 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2325 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2326 } else {
2327 0
2328 };
2329 let norm = cfg.hidden_size;
2330 embed_tokens + lm_head + norm
2331 };
2332 Ok(elems * dtype.size_in_bytes())
2333 }
2334
2335 fn layer_sizes_in_bytes(
2336 &self,
2337 config: &str,
2338 dtype: DType,
2339 weight_pack_factor: usize,
2340 _matformer_config: Option<&MatformerSliceConfig>,
2341 ) -> Result<Vec<usize>> {
2342 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2343
2344 let per_layer_elems = {
2345 let input_layernorm = cfg.hidden_size;
2346 let post_attention_layernorm = cfg.hidden_size;
2347
2348 let size_in = cfg.hidden_size;
2349 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2350 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2351 let q_proj =
2352 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2353 let k_proj =
2354 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2355 let v_proj =
2356 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2357 let o_proj =
2358 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2359
2360 let moe_block = {
2361 let gate = cfg.hidden_size * cfg.num_local_experts;
2362 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2364 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2365 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2366 gate + cfg.num_local_experts * w1
2367 + cfg.num_local_experts * w2
2368 + cfg.num_local_experts * w3
2369 };
2370
2371 input_layernorm
2372 + post_attention_layernorm
2373 + q_proj
2374 + k_proj
2375 + v_proj
2376 + o_proj
2377 + moe_block
2378 };
2379 Ok(vec![
2380 per_layer_elems * dtype.size_in_bytes();
2381 cfg.num_hidden_layers
2382 ])
2383 }
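// Worked example (hypothetical numbers, not from a real config): with
// hidden_size = 4096, intermediate_size = 6400, num_local_experts = 16 and
// weight_pack_factor = 1, the router gate holds 4096 * 16 = 65_536 elements and
// each expert 3 * 4096 * 6400 = 78_643_200, so one layer's MoE block is about
// 16 * 78_643_200 + 65_536 ≈ 1.26e9 elements, dwarfing the attention
// projections, which is presumably why the MoQE regexes above target only the
// experts.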
2384
2385 fn num_layers(&self, config: &str) -> Result<usize> {
2386 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2387
2388 Ok(cfg.num_hidden_layers)
2389 }
2390
2391 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2392 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2393
2394 let cfg = ModelConfigMetadata {
2395 max_seq_len: cfg.max_position_embeddings,
2396 num_layers: cfg.num_hidden_layers,
2397 hidden_size: cfg.hidden_size,
2398 num_kv_heads: cfg.num_key_value_heads,
2399 num_attn_heads: cfg.num_attention_heads,
2400 sliding_window: cfg.sliding_window,
2401 k_head_dim: cfg.head_dim(),
2402 v_head_dim: cfg.head_dim(),
2403 };
2404
2405 Ok(Box::new(cfg))
2406 }
2407}
2408
2409pub struct DeepSeekV2Loader;
2413
2414impl NormalModelLoader for DeepSeekV2Loader {
2415 fn load(
2416 &self,
2417 config: &str,
2418 vb: ShardedVarBuilder,
2419 normal_loading_metadata: NormalLoadingMetadata,
2420 attention_mechanism: AttentionImplementation,
2421 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2422 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2423
2424 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2425 &cfg,
2426 vb,
2427 self.is_gptx(config)?,
2428 normal_loading_metadata,
2429 attention_mechanism,
2430 )?))
2431 }
2432 fn load_xlora(
2433 &self,
2434 _config: &str,
2435 _vb: ShardedVarBuilder,
2436 _lora_config: &[((String, String), LoraConfig)],
2437 _xlora_config: Option<XLoraConfig>,
2438 _xlora_ordering: Ordering,
2439 _normal_loading_metadata: NormalLoadingMetadata,
2440 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2441 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2442 todo!()
2443 }
2444 fn is_gptx(&self, _: &str) -> Result<bool> {
2445 Ok(true)
2446 }
2447 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2448 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2449 Ok(Box::new(cfg))
2450 }
2451}
2452
2453impl IsqModelLoader for DeepSeekV2Loader {
2454 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2455 let mut data = vec![
2456 Regex::new(r"lm_head\.(weight|bias)$")?,
2457 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2459 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2460 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2461 ];
2462 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2463 if cfg.q_lora_rank.is_some() {
2464 data.extend(vec![
2465 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2466 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2467 ]);
2468 } else {
2469 data.push(Regex::new(
2470 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2471 )?);
2472 }
2473 for layer_idx in 0..cfg.num_hidden_layers {
2474 if cfg.n_routed_experts.is_some()
2475 && layer_idx >= cfg.first_k_dense_replace
2476 && layer_idx % cfg.moe_layer_freq == 0
2477 {
2478 for i in 0..cfg.n_routed_experts.unwrap() {
2479 data.extend(vec![
2480 Regex::new(&format!(
2481 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2482 ))?,
2483 Regex::new(&format!(
2484 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2485 ))?,
2486 Regex::new(&format!(
2487 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2488 ))?,
2489 ]);
2490 }
2491 if cfg.n_shared_experts.is_some() {
2492 data.extend(vec![
2493 Regex::new(&format!(
2494 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2495 ))?,
2496 Regex::new(&format!(
2497 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2498 ))?,
2499 Regex::new(&format!(
2500 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2501 ))?,
2502 ]);
2503 }
2504 } else {
2505 data.extend(vec![
2506 Regex::new(&format!(
2507 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2508 ))?,
2509 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2510 Regex::new(&format!(
2511 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2512 ))?,
2513 ]);
2514 };
2515 }
2516 Ok(data)
2517 }
2518 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2519 self.isq_layer_regexes(config)
2520 }
2521
2522 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2523 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2524 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2525 for layer_idx in 0..cfg.num_hidden_layers {
2526 if cfg.n_routed_experts.is_some()
2527 && layer_idx >= cfg.first_k_dense_replace
2528 && layer_idx % cfg.moe_layer_freq == 0
2529 {
2530 for i in 0..cfg.n_routed_experts.unwrap() {
2531 data.extend(vec![
2532 Regex::new(&format!(
2533 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2534 ))?,
2535 Regex::new(&format!(
2536 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2537 ))?,
2538 Regex::new(&format!(
2539 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2540 ))?,
2541 ]);
2542 }
2543 if cfg.n_shared_experts.is_some() {
2544 data.extend(vec![
2545 Regex::new(&format!(
2546 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2547 ))?,
2548 Regex::new(&format!(
2549 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2550 ))?,
2551 Regex::new(&format!(
2552 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2553 ))?,
2554 ]);
2555 }
2556 } else {
2557 data.extend(vec![
2558 Regex::new(&format!(
2559 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2560 ))?,
2561 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2562 Regex::new(&format!(
2563 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2564 ))?,
2565 ]);
2566 };
2567 }
2568 Ok(data)
2569 }
2570 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2571 self.isq_layer_regexes_moqe(config)
2572 }
2573}
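// Editorial note: the expert layout differs per layer (dense MLP for the first
// `first_k_dense_replace` layers, routed plus optional shared experts after
// that, gated by `moe_layer_freq`), so the regexes are generated per concrete
// layer index instead of a single `(\d+)` wildcard. A routed layer 7 would, for
// example, yield `layers\.7\.mlp\.experts\.0\.gate_proj\.(weight|bias)$`, while
// a dense layer 0 yields `layers\.0\.mlp\.gate_proj\.(weight|bias)$` (indices
// hypothetical).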
2574
2575impl DeviceMappedModelLoader for DeepSeekV2Loader {
2576 fn mapped_max_act_size_elems(
2577 &self,
2578 config: &str,
2579 params: &AutoDeviceMapParams,
2580 ) -> Result<usize> {
2581 let AutoDeviceMapParams::Text {
2582 max_seq_len,
2583 max_batch_size,
2584 } = params
2585 else {
2586 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2587 };
2588
2589 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2590
2591 Ok(
2592 max_batch_size
2593 * cfg.num_attention_heads
2594 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2595 )
2596 }
2597 fn non_mapped_max_act_size_elems(
2598 &self,
2599 _config: &str,
2600 _params: &AutoDeviceMapParams,
2601 ) -> Result<usize> {
2602 Ok(0)
2603 }
2604
2605 fn non_mapped_size_in_bytes(
2606 &self,
2607 config: &str,
2608 dtype: DType,
2609 weight_pack_factor: usize,
2610 _matformer_config: Option<&MatformerSliceConfig>,
2611 ) -> Result<usize> {
2612 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2613 let elems = {
2614 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2615 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2617 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2618 } else {
2619 0
2620 };
2621 let norm = cfg.hidden_size;
2622 embed_tokens + lm_head + norm
2623 };
2624 Ok(elems * dtype.size_in_bytes())
2625 }
2626
2627 fn layer_sizes_in_bytes(
2628 &self,
2629 config: &str,
2630 dtype: DType,
2631 weight_pack_factor: usize,
2632 _matformer_config: Option<&MatformerSliceConfig>,
2633 ) -> Result<Vec<usize>> {
2634 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2635 let mut per_layer_elems = Vec::new();
2636
2637 for layer_idx in 0..cfg.num_hidden_layers {
2638 let input_layernorm = cfg.hidden_size;
2639 let post_attention_layernorm = cfg.hidden_size;
2640
2641 let q_proj = match cfg.q_lora_rank {
2642 Some(lora_rank) => {
2643 let a = cfg.hidden_size * lora_rank;
2644 let norm = lora_rank;
2645 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2646 a + norm + b
2647 }
2648 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2649 };
2650 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2651 / weight_pack_factor
2652 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2653 let kv_a_layernorm = cfg.kv_lora_rank;
2654 let kv_b_proj = cfg.kv_lora_rank
2655 * cfg.num_attention_heads
2656 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2657 / weight_pack_factor;
2658 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2659 / weight_pack_factor
2660 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2661
2662 let moe_block = {
2663 let mut sum = 0;
2664 if cfg.n_routed_experts.is_some()
2665 && layer_idx >= cfg.first_k_dense_replace
2666 && layer_idx % cfg.moe_layer_freq == 0
2667 {
2668 let h_size = cfg.hidden_size;
2669 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2670 * cfg.n_routed_experts.unwrap();
2671 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2672 * cfg.n_routed_experts.unwrap();
2673 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2674 * cfg.n_routed_experts.unwrap();
2675 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2676 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2677 / weight_pack_factor;
2678 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2679 / weight_pack_factor;
2680 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2681 / weight_pack_factor;
2682 gate_proj + up_proj + down_proj
2683 } else {
2684 0
2685 };
2686 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2687 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2688 } else {
2689 let h_size = cfg.hidden_size;
2690 let i_size = cfg.intermediate_size;
2691 let gate_proj = h_size * i_size / weight_pack_factor;
2692 let up_proj = h_size * i_size / weight_pack_factor;
2693 let down_proj = i_size * h_size / weight_pack_factor;
2694 sum += gate_proj + up_proj + down_proj;
2695 }
2696 sum
2697 };
2698
2699 per_layer_elems.push(
2700 input_layernorm
2701 + post_attention_layernorm
2702 + q_proj
2703 + kv_a_layernorm
2704 + kv_a_proj_with_mqa
2705 + kv_b_proj
2706 + o_proj
2707 + moe_block,
2708 );
2709 }
2710
2711 Ok(per_layer_elems
2712 .into_iter()
2713 .map(|x| x * dtype.size_in_bytes())
2714 .collect())
2715 }
2716
2717 fn num_layers(&self, config: &str) -> Result<usize> {
2718 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2719 Ok(cfg.num_hidden_layers)
2720 }
2721
2722 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2723 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2724
2725 let cfg = ModelConfigMetadata {
2726 max_seq_len: cfg.max_position_embeddings,
2727 num_layers: cfg.num_hidden_layers,
2728 hidden_size: cfg.hidden_size,
2729 num_kv_heads: cfg.num_attention_heads,
2730 num_attn_heads: cfg.num_attention_heads,
2731 sliding_window: None,
2732 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2733 v_head_dim: cfg.v_head_dim,
2734 };
2735
2736 Ok(Box::new(cfg))
2737 }
2738}
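// Editorial note on the metadata above: DeepSeek-V2 uses multi-head latent
// attention, so the reported k_head_dim is the concatenation of the RoPE and
// non-RoPE parts (`qk_rope_head_dim + qk_nope_head_dim`) while values keep
// their own `v_head_dim`. `num_kv_heads` is set to `num_attention_heads`,
// presumably because the decompressed KV cache is stored per query head.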
2739
2740pub struct DeepSeekV3Loader;
2744
2745impl NormalModelLoader for DeepSeekV3Loader {
2746 fn load(
2747 &self,
2748 config: &str,
2749 vb: ShardedVarBuilder,
2750 normal_loading_metadata: NormalLoadingMetadata,
2751 attention_mechanism: AttentionImplementation,
2752 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2753 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2754 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2755 &cfg,
2756 vb,
2757 self.is_gptx(config)?,
2758 normal_loading_metadata,
2759 attention_mechanism,
2760 )?))
2761 }
2762 fn load_xlora(
2763 &self,
2764 _config: &str,
2765 _vb: ShardedVarBuilder,
2766 _lora_config: &[((String, String), LoraConfig)],
2767 _xlora_config: Option<XLoraConfig>,
2768 _xlora_ordering: Ordering,
2769 _normal_loading_metadata: NormalLoadingMetadata,
2770 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2771 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2772 todo!()
2773 }
2774 fn is_gptx(&self, _: &str) -> Result<bool> {
2775 Ok(true)
2776 }
2777 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2778 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2779 Ok(Box::new(cfg))
2780 }
2781}
2782
2783impl IsqModelLoader for DeepSeekV3Loader {
2784 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2785 let mut data = vec![
2786 Regex::new(r"lm_head\.(weight|bias)$")?,
2787 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2789 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2790 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2791 ];
2792 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2793 if cfg.q_lora_rank.is_some() {
2794 data.extend(vec![
2795 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2796 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2797 ]);
2798 } else {
2799 data.push(Regex::new(
2800 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2801 )?);
2802 }
2803 for layer_idx in 0..cfg.num_hidden_layers {
2804 if cfg.n_routed_experts.is_some()
2805 && layer_idx >= cfg.first_k_dense_replace
2806 && layer_idx % cfg.moe_layer_freq == 0
2807 {
2808 for i in 0..cfg.n_routed_experts.unwrap() {
2809 data.extend(vec![
2810 Regex::new(&format!(
2811 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2812 ))?,
2813 Regex::new(&format!(
2814 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2815 ))?,
2816 Regex::new(&format!(
2817 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2818 ))?,
2819 ]);
2820 }
2821 if cfg.n_shared_experts.is_some() {
2822 data.extend(vec![
2823 Regex::new(&format!(
2824 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2825 ))?,
2826 Regex::new(&format!(
2827 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2828 ))?,
2829 Regex::new(&format!(
2830 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2831 ))?,
2832 ]);
2833 }
2834 } else {
2835 data.extend(vec![
2836 Regex::new(&format!(
2837 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2838 ))?,
2839 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2840 Regex::new(&format!(
2841 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2842 ))?,
2843 ]);
2844 };
2845 }
2846 Ok(data)
2847 }
2848 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2849 self.isq_layer_regexes(config)
2850 }
2851
2852 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2853 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2854 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2855 for layer_idx in 0..cfg.num_hidden_layers {
2856 if cfg.n_routed_experts.is_some()
2857 && layer_idx >= cfg.first_k_dense_replace
2858 && layer_idx % cfg.moe_layer_freq == 0
2859 {
2860 for i in 0..cfg.n_routed_experts.unwrap() {
2861 data.extend(vec![
2862 Regex::new(&format!(
2863 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2864 ))?,
2865 Regex::new(&format!(
2866 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2867 ))?,
2868 Regex::new(&format!(
2869 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2870 ))?,
2871 ]);
2872 }
2873 if cfg.n_shared_experts.is_some() {
2874 data.extend(vec![
2875 Regex::new(&format!(
2876 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2877 ))?,
2878 Regex::new(&format!(
2879 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2880 ))?,
2881 Regex::new(&format!(
2882 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2883 ))?,
2884 ]);
2885 }
2886 } else {
2887 data.extend(vec![
2888 Regex::new(&format!(
2889 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2890 ))?,
2891 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2892 Regex::new(&format!(
2893 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2894 ))?,
2895 ]);
2896 };
2897 }
2898 Ok(data)
2899 }
2900 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2901 self.isq_layer_regexes_moqe(config)
2902 }
2903}
2904
2905impl DeviceMappedModelLoader for DeepSeekV3Loader {
2906 fn mapped_max_act_size_elems(
2907 &self,
2908 config: &str,
2909 params: &AutoDeviceMapParams,
2910 ) -> Result<usize> {
2911 let AutoDeviceMapParams::Text {
2912 max_seq_len,
2913 max_batch_size,
2914 } = params
2915 else {
2916 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2917 };
2918
2919 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2920
2921 Ok(
2922 max_batch_size
2923 * cfg.num_attention_heads
2924 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2925 )
2926 }
2927 fn non_mapped_max_act_size_elems(
2928 &self,
2929 _config: &str,
2930 _params: &AutoDeviceMapParams,
2931 ) -> Result<usize> {
2932 Ok(0)
2933 }
2934
2935 fn non_mapped_size_in_bytes(
2936 &self,
2937 config: &str,
2938 dtype: DType,
2939 weight_pack_factor: usize,
2940 _matformer_config: Option<&MatformerSliceConfig>,
2941 ) -> Result<usize> {
2942 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2943 let elems = {
2944 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2945 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2947 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2948 } else {
2949 0
2950 };
2951 let norm = cfg.hidden_size;
2952 embed_tokens + lm_head + norm
2953 };
2954 Ok(elems * dtype.size_in_bytes())
2955 }
2956
2957 fn layer_sizes_in_bytes(
2958 &self,
2959 config: &str,
2960 dtype: DType,
2961 weight_pack_factor: usize,
2962 _matformer_config: Option<&MatformerSliceConfig>,
2963 ) -> Result<Vec<usize>> {
2964 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2965 let mut per_layer_elems = Vec::new();
2966
2967 for layer_idx in 0..cfg.num_hidden_layers {
2968 let input_layernorm = cfg.hidden_size;
2969 let post_attention_layernorm = cfg.hidden_size;
2970
2971 let q_proj = match cfg.q_lora_rank {
2972 Some(lora_rank) => {
2973 let a = cfg.hidden_size * lora_rank;
2974 let norm = lora_rank;
2975 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2976 a + norm + b
2977 }
2978 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2979 };
2980 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2981 / weight_pack_factor
2982 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2983 let kv_a_layernorm = cfg.kv_lora_rank;
2984 let kv_b_proj = cfg.kv_lora_rank
2985 * cfg.num_attention_heads
2986 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2987 / weight_pack_factor;
2988 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2989 / weight_pack_factor
2990 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2991
2992 let moe_block = {
2993 let mut sum = 0;
2994 if cfg.n_routed_experts.is_some()
2995 && layer_idx >= cfg.first_k_dense_replace
2996 && layer_idx % cfg.moe_layer_freq == 0
2997 {
2998 let h_size = cfg.hidden_size;
2999 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3000 * cfg.n_routed_experts.unwrap();
3001 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3002 * cfg.n_routed_experts.unwrap();
3003 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3004 * cfg.n_routed_experts.unwrap();
3005 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3006 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3007 / weight_pack_factor;
3008 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3009 / weight_pack_factor;
3010 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3011 / weight_pack_factor;
3012 gate_proj + up_proj + down_proj
3013 } else {
3014 0
3015 };
3016 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
3017 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3018 } else {
3019 let h_size = cfg.hidden_size;
3020 let i_size = cfg.intermediate_size;
3021 let gate_proj = h_size * i_size / weight_pack_factor;
3022 let up_proj = h_size * i_size / weight_pack_factor;
3023 let down_proj = i_size * h_size / weight_pack_factor;
3024 sum += gate_proj + up_proj + down_proj;
3025 }
3026 sum
3027 };
3028
3029 per_layer_elems.push(
3030 input_layernorm
3031 + post_attention_layernorm
3032 + q_proj
3033 + kv_a_layernorm
3034 + kv_a_proj_with_mqa
3035 + kv_b_proj
3036 + o_proj
3037 + moe_block,
3038 );
3039 }
3040
3041 Ok(per_layer_elems
3042 .into_iter()
3043 .map(|x| x * dtype.size_in_bytes())
3044 .collect())
3045 }
3046
3047 fn num_layers(&self, config: &str) -> Result<usize> {
3048 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3049 Ok(cfg.num_hidden_layers)
3050 }
3051
3052 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3053 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3054
3055 let cfg = ModelConfigMetadata {
3056 max_seq_len: cfg.max_position_embeddings,
3057 num_layers: cfg.num_hidden_layers,
3058 hidden_size: cfg.hidden_size,
3059 num_kv_heads: cfg.num_attention_heads,
3060 num_attn_heads: cfg.num_attention_heads,
3061 sliding_window: None,
3062 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3063 v_head_dim: cfg.v_head_dim,
3064 };
3065
3066 Ok(Box::new(cfg))
3067 }
3068}
3069
3070pub struct Qwen3Loader;
3074
3075impl NormalModelLoader for Qwen3Loader {
3076 fn load(
3077 &self,
3078 config: &str,
3079 vb: ShardedVarBuilder,
3080 normal_loading_metadata: NormalLoadingMetadata,
3081 attention_mechanism: AttentionImplementation,
3082 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3083 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3084
3085 Ok(Box::new(models::qwen3::Model::new(
3086 &cfg,
3087 vb,
3088 self.is_gptx(config)?,
3089 normal_loading_metadata,
3090 attention_mechanism,
3091 )?))
3092 }
3093 fn load_xlora(
3094 &self,
3095 _config: &str,
3096 _vb: ShardedVarBuilder,
3097 _lora_config: &[((String, String), LoraConfig)],
3098 _xlora_config: Option<XLoraConfig>,
3099 _xlora_ordering: Ordering,
3100 _normal_loading_metadata: NormalLoadingMetadata,
3101 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3102 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3103 todo!()
3104 }
3105 fn is_gptx(&self, _: &str) -> Result<bool> {
3106 Ok(true)
3107 }
3108 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3109 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3110
3111 Ok(Box::new(cfg))
3112 }
3113}
3114
3115impl IsqModelLoader for Qwen3Loader {
3116 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3117 Ok(vec![
3118 Regex::new(r"lm_head\.(weight|bias)$")?,
3119 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3121 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3122 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3123 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3124 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3126 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3127 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3128 ])
3129 }
3130 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3131 self.isq_layer_regexes(config)
3132 }
3133}
3134
3135impl DeviceMappedModelLoader for Qwen3Loader {
3136 fn mapped_max_act_size_elems(
3137 &self,
3138 config: &str,
3139 params: &AutoDeviceMapParams,
3140 ) -> Result<usize> {
3141 let AutoDeviceMapParams::Text {
3142 max_seq_len,
3143 max_batch_size,
3144 } = params
3145 else {
3146 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3147 };
3148
3149 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3150
3151 Ok(
3152 max_batch_size
3153 * cfg.num_attention_heads
3154 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3155 )
3156 }
3157 fn non_mapped_max_act_size_elems(
3158 &self,
3159 _config: &str,
3160 _params: &AutoDeviceMapParams,
3161 ) -> Result<usize> {
3162 Ok(0)
3163 }
3164
3165 fn non_mapped_size_in_bytes(
3166 &self,
3167 config: &str,
3168 dtype: DType,
3169 weight_pack_factor: usize,
3170 _matformer_config: Option<&MatformerSliceConfig>,
3171 ) -> Result<usize> {
3172 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3173 let elems = {
3174 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3175 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3177 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3178 } else {
3179 0
3180 };
3181 let norm = cfg.hidden_size;
3182 embed_tokens + lm_head + norm
3183 };
3184 Ok(elems * dtype.size_in_bytes())
3185 }
3186
3187 fn layer_sizes_in_bytes(
3188 &self,
3189 config: &str,
3190 dtype: DType,
3191 weight_pack_factor: usize,
3192 _matformer_config: Option<&MatformerSliceConfig>,
3193 ) -> Result<Vec<usize>> {
3194 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3195 let per_layer_elems = {
3196 let input_layernorm = cfg.hidden_size;
3197 let post_attention_layernorm = cfg.hidden_size;
3198
3199 let size_in = cfg.hidden_size;
3200 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3201 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3202 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3203 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3204 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3205 let o_proj = size_q * size_in / weight_pack_factor;
3206
3207 let h_size = cfg.hidden_size;
3208 let i_size = cfg.intermediate_size;
3209 let gate_proj = h_size * i_size / weight_pack_factor;
3210 let up_proj = h_size * i_size / weight_pack_factor;
3211 let down_proj = i_size * h_size / weight_pack_factor;
3212
3213 let q_norm = cfg.head_dim();
3214 let k_norm = cfg.head_dim();
3215
3216 input_layernorm
3217 + post_attention_layernorm
3218 + q_proj
3219 + k_proj
3220 + v_proj
3221 + o_proj
3222 + gate_proj
3223 + up_proj
3224 + down_proj
3225 + q_norm
3226 + k_norm
3227 };
3228 Ok(vec![
3229 per_layer_elems * dtype.size_in_bytes();
3230 cfg.num_hidden_layers
3231 ])
3232 }
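// Editorial note: the `q_norm`/`k_norm` terms account for Qwen3's per-head
// RMSNorm on the query and key projections; each is a vector of length
// `head_dim()`, negligible next to the projection matrices but included for
// completeness.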
3233
3234 fn num_layers(&self, config: &str) -> Result<usize> {
3235 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3236 Ok(cfg.num_hidden_layers)
3237 }
3238
3239 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3240 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3241
3242 let cfg = ModelConfigMetadata {
3243 max_seq_len: cfg.max_position_embeddings,
3244 num_layers: cfg.num_hidden_layers,
3245 hidden_size: cfg.hidden_size,
3246 num_kv_heads: cfg.num_key_value_heads,
3247 num_attn_heads: cfg.num_attention_heads,
3248 sliding_window: cfg.sliding_window,
3249 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3250 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3251 };
3252
3253 Ok(Box::new(cfg))
3254 }
3255}
3256
3257pub struct GLM4Loader;
3261
3262impl NormalModelLoader for GLM4Loader {
3263 fn load(
3264 &self,
3265 config: &str,
3266 vb: ShardedVarBuilder,
3267 normal_loading_metadata: NormalLoadingMetadata,
3268 attention_mechanism: AttentionImplementation,
3269 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3270 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3271
3272 Ok(Box::new(models::glm4::Model::new(
3273 &cfg,
3274 vb,
3275 self.is_gptx(config)?,
3276 normal_loading_metadata,
3277 attention_mechanism,
3278 )?))
3279 }
3280 fn load_xlora(
3281 &self,
3282 _config: &str,
3283 _vb: ShardedVarBuilder,
3284 _lora_config: &[((String, String), LoraConfig)],
3285 _xlora_config: Option<XLoraConfig>,
3286 _xlora_ordering: Ordering,
3287 _normal_loading_metadata: NormalLoadingMetadata,
3288 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3289 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3290 todo!()
3291 }
3292 fn is_gptx(&self, _: &str) -> Result<bool> {
3293 Ok(true)
3294 }
3295 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3296 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3297
3298 Ok(Box::new(cfg))
3299 }
3300}
3301
3302impl IsqModelLoader for GLM4Loader {
3303 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3304 Ok(vec![
3305 Regex::new(r"lm_head\.(weight|bias)$")?,
3306 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3308 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3309 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3310 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3311 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3313 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3314 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3315 ])
3316 }
3317 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3318 self.isq_layer_regexes(config)
3319 }
3320}
3321
3322impl DeviceMappedModelLoader for GLM4Loader {
3323 fn mapped_max_act_size_elems(
3324 &self,
3325 config: &str,
3326 params: &AutoDeviceMapParams,
3327 ) -> Result<usize> {
3328 let AutoDeviceMapParams::Text {
3329 max_seq_len,
3330 max_batch_size,
3331 } = params
3332 else {
3333 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3334 };
3335
3336 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3337
3338 Ok(
3339 max_batch_size
3340 * cfg.num_attention_heads
3341 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3342 )
3343 }
3344 fn non_mapped_max_act_size_elems(
3345 &self,
3346 _config: &str,
3347 _params: &AutoDeviceMapParams,
3348 ) -> Result<usize> {
3349 Ok(0)
3350 }
3351
3352 fn non_mapped_size_in_bytes(
3353 &self,
3354 config: &str,
3355 dtype: DType,
3356 weight_pack_factor: usize,
3357 _matformer_config: Option<&MatformerSliceConfig>,
3358 ) -> Result<usize> {
3359 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3360 let elems = {
3361 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3362 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3364 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3365 } else {
3366 0
3367 };
3368 let norm = cfg.hidden_size;
3369 embed_tokens + lm_head + norm
3370 };
3371 Ok(elems * dtype.size_in_bytes())
3372 }
3373
3374 fn layer_sizes_in_bytes(
3375 &self,
3376 config: &str,
3377 dtype: DType,
3378 weight_pack_factor: usize,
3379 _matformer_config: Option<&MatformerSliceConfig>,
3380 ) -> Result<Vec<usize>> {
3381 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3382 let per_layer_elems = {
3383 let input_layernorm = cfg.hidden_size;
3384 let post_attention_layernorm = cfg.hidden_size * 3; // includes GLM4's post_self_attn_layernorm and post_mlp_layernorm
3385
3386 let size_in = cfg.hidden_size;
3387 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3388 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3389 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3390 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3391 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3392 let o_proj = size_q * size_in / weight_pack_factor;
3393
3394 let h_size = cfg.hidden_size;
3395 let i_size = cfg.intermediate_size;
3396 let gate_proj = h_size * i_size / weight_pack_factor;
3397 let up_proj = h_size * i_size / weight_pack_factor;
3398 let down_proj = i_size * h_size / weight_pack_factor;
3399
3400 input_layernorm
3401 + post_attention_layernorm
3402 + q_proj
3403 + k_proj
3404 + v_proj
3405 + o_proj
3406 + gate_proj
3407 + up_proj
3408 + down_proj
3409 };
3410 Ok(vec![
3411 per_layer_elems * dtype.size_in_bytes();
3412 cfg.num_hidden_layers
3413 ])
3414 }
3415
3416 fn num_layers(&self, config: &str) -> Result<usize> {
3417 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3418 Ok(cfg.num_hidden_layers)
3419 }
3420
3421 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3422 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3423
3424 let cfg = ModelConfigMetadata {
3425 max_seq_len: cfg.max_position_embeddings,
3426 num_layers: cfg.num_hidden_layers,
3427 hidden_size: cfg.hidden_size,
3428 num_kv_heads: cfg.num_key_value_heads,
3429 num_attn_heads: cfg.num_attention_heads,
3430 sliding_window: cfg.sliding_window,
3431 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3432 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3433 };
3434
3435 Ok(Box::new(cfg))
3436 }
3437}
3438
3439pub struct GLM4MoeLiteLoader;
3443
3444impl NormalModelLoader for GLM4MoeLiteLoader {
3445 fn load(
3446 &self,
3447 config: &str,
3448 vb: ShardedVarBuilder,
3449 normal_loading_metadata: NormalLoadingMetadata,
3450 attention_mechanism: AttentionImplementation,
3451 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3452 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3453 Ok(Box::new(models::glm4_moe_lite::Glm4MoeLite::new(
3454 &cfg,
3455 vb,
3456 self.is_gptx(config)?,
3457 normal_loading_metadata,
3458 attention_mechanism,
3459 )?))
3460 }
3461 fn load_xlora(
3462 &self,
3463 _config: &str,
3464 _vb: ShardedVarBuilder,
3465 _lora_config: &[((String, String), LoraConfig)],
3466 _xlora_config: Option<XLoraConfig>,
3467 _xlora_ordering: Ordering,
3468 _normal_loading_metadata: NormalLoadingMetadata,
3469 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3470 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3471 todo!()
3472 }
3473 fn is_gptx(&self, _: &str) -> Result<bool> {
3474 Ok(true)
3475 }
3476 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3477 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3478 Ok(Box::new(cfg))
3479 }
3480}
3481
3482impl IsqModelLoader for GLM4MoeLiteLoader {
3483 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3484 let mut data = vec![
3485 Regex::new(r"lm_head\.(weight|bias)$")?,
3486 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3488 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3489 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3490 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3492 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3493 ];
3494 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3495 for layer_idx in 0..cfg.num_hidden_layers {
3496 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3497 for i in 0..cfg.n_routed_experts {
3499 data.extend(vec![
3500 Regex::new(&format!(
3501 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3502 ))?,
3503 Regex::new(&format!(
3504 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3505 ))?,
3506 Regex::new(&format!(
3507 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3508 ))?,
3509 ]);
3510 }
3511 if cfg.n_shared_experts > 0 {
3512 data.extend(vec![
3513 Regex::new(&format!(
3514 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3515 ))?,
3516 Regex::new(&format!(
3517 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3518 ))?,
3519 Regex::new(&format!(
3520 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3521 ))?,
3522 ]);
3523 }
3524 } else {
3525 data.extend(vec![
3527 Regex::new(&format!(
3528 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3529 ))?,
3530 Regex::new(&format!(
3531 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3532 ))?,
3533 Regex::new(&format!(
3534 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3535 ))?,
3536 ]);
3537 };
3538 }
3539 Ok(data)
3540 }
3541 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3542 self.isq_layer_regexes(config)
3543 }
3544
3545 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3546 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3547 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3548 for layer_idx in 0..cfg.num_hidden_layers {
3549 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3550 for i in 0..cfg.n_routed_experts {
3552 data.extend(vec![
3553 Regex::new(&format!(
3554 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3555 ))?,
3556 Regex::new(&format!(
3557 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3558 ))?,
3559 Regex::new(&format!(
3560 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3561 ))?,
3562 ]);
3563 }
3564 if cfg.n_shared_experts > 0 {
3565 data.extend(vec![
3566 Regex::new(&format!(
3567 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3568 ))?,
3569 Regex::new(&format!(
3570 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3571 ))?,
3572 Regex::new(&format!(
3573 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3574 ))?,
3575 ]);
3576 }
3577 } else {
3578 data.extend(vec![
3580 Regex::new(&format!(
3581 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3582 ))?,
3583 Regex::new(&format!(
3584 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3585 ))?,
3586 Regex::new(&format!(
3587 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3588 ))?,
3589 ]);
3590 };
3591 }
3592 Ok(data)
3593 }
3594 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3595 self.isq_layer_regexes_moqe(config)
3596 }
3597}
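// Editorial note: unlike the DeepSeek loaders above, the q_a_proj/q_b_proj
// regexes are added unconditionally here; the lite GLM4-MoE configuration is
// treated as always using the low-rank query path, matching the unconditional
// use of `cfg.q_lora_rank` in `layer_sizes_in_bytes` below.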
3598
3599impl DeviceMappedModelLoader for GLM4MoeLiteLoader {
3600 fn mapped_max_act_size_elems(
3601 &self,
3602 config: &str,
3603 params: &AutoDeviceMapParams,
3604 ) -> Result<usize> {
3605 let AutoDeviceMapParams::Text {
3606 max_seq_len,
3607 max_batch_size,
3608 } = params
3609 else {
3610 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3611 };
3612
3613 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3614
3615 Ok(
3616 max_batch_size
3617 * cfg.num_attention_heads
3618 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3619 )
3620 }
3621 fn non_mapped_max_act_size_elems(
3622 &self,
3623 _config: &str,
3624 _params: &AutoDeviceMapParams,
3625 ) -> Result<usize> {
3626 Ok(0)
3627 }
3628
3629 fn non_mapped_size_in_bytes(
3630 &self,
3631 config: &str,
3632 dtype: DType,
3633 weight_pack_factor: usize,
3634 _matformer_config: Option<&MatformerSliceConfig>,
3635 ) -> Result<usize> {
3636 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3637 let elems = {
3638 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3639 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3641 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3642 } else {
3643 0
3644 };
3645 let norm = cfg.hidden_size;
3646 embed_tokens + lm_head + norm
3647 };
3648 Ok(elems * dtype.size_in_bytes())
3649 }
3650
3651 fn layer_sizes_in_bytes(
3652 &self,
3653 config: &str,
3654 dtype: DType,
3655 weight_pack_factor: usize,
3656 _matformer_config: Option<&MatformerSliceConfig>,
3657 ) -> Result<Vec<usize>> {
3658 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3659 let mut per_layer_elems = Vec::new();
3660
3661 for layer_idx in 0..cfg.num_hidden_layers {
3662 let input_layernorm = cfg.hidden_size;
3663 let post_attention_layernorm = cfg.hidden_size;
3664
3665 let q_proj = {
3667 let a = cfg.hidden_size * cfg.q_lora_rank / weight_pack_factor;
3668 let norm = cfg.q_lora_rank;
3669 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.q_lora_rank
3670 / weight_pack_factor;
3671 a + norm + b
3672 };
3673 let kv_a_proj_with_mqa =
3674 cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim) / weight_pack_factor;
3675 let kv_a_layernorm = cfg.kv_lora_rank;
3676 let kv_b_proj = cfg.kv_lora_rank
3677 * cfg.num_attention_heads
3678 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3679 / weight_pack_factor;
3680 let o_proj =
3681 cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size / weight_pack_factor;
3682
3683 let moe_block = {
3684 let mut sum = 0;
3685 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3686 let h_size = cfg.hidden_size;
3688 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3689 * cfg.n_routed_experts;
3690 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3691 * cfg.n_routed_experts;
3692 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3693 * cfg.n_routed_experts;
3694 let shared_experts = if cfg.n_shared_experts > 0 {
3695 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3696 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3697 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
3698 gate_proj + up_proj + down_proj
3699 } else {
3700 0
3701 };
3702 let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
3703 let e_score_correction_bias = cfg.n_routed_experts;
3704 sum += gate_proj
3705 + up_proj
3706 + down_proj
3707 + shared_experts
3708 + gate_weight
3709 + e_score_correction_bias;
3710 } else {
3711 let h_size = cfg.hidden_size;
3713 let i_size = cfg.intermediate_size;
3714 let gate_proj = h_size * i_size / weight_pack_factor;
3715 let up_proj = h_size * i_size / weight_pack_factor;
3716 let down_proj = i_size * h_size / weight_pack_factor;
3717 sum += gate_proj + up_proj + down_proj;
3718 }
3719 sum
3720 };
3721
3722 per_layer_elems.push(
3723 input_layernorm
3724 + post_attention_layernorm
3725 + q_proj
3726 + kv_a_layernorm
3727 + kv_a_proj_with_mqa
3728 + kv_b_proj
3729 + o_proj
3730 + moe_block,
3731 );
3732 }
3733
3734 Ok(per_layer_elems
3735 .into_iter()
3736 .map(|x| x * dtype.size_in_bytes())
3737 .collect())
3738 }
3739
3740 fn num_layers(&self, config: &str) -> Result<usize> {
3741 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3742 Ok(cfg.num_hidden_layers)
3743 }
3744
3745 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3746 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3747
3748 let cfg = ModelConfigMetadata {
3749 max_seq_len: cfg.max_position_embeddings,
3750 num_layers: cfg.num_hidden_layers,
3751 hidden_size: cfg.hidden_size,
3752 num_kv_heads: cfg.num_attention_heads,
3753 num_attn_heads: cfg.num_attention_heads,
3754 sliding_window: None,
3755 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3756 v_head_dim: cfg.v_head_dim,
3757 };
3758
3759 Ok(Box::new(cfg))
3760 }
3761}
3762
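/// Loader for GLM-4 MoE text models (`models::glm4_moe`).
///
/// Illustrative sketch of the device-mapping size query this loader serves;
/// the variable names below are assumptions for the example only:
///
/// ```ignore
/// let loader = GLM4MoeLoader;
/// // Per-layer weight sizes in bytes for BF16 weights with no ISQ packing.
/// let sizes = loader.layer_sizes_in_bytes(config_json, DType::BF16, 1, None)?;
/// ```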
3763pub struct GLM4MoeLoader;
3767
3768impl NormalModelLoader for GLM4MoeLoader {
3769 fn load(
3770 &self,
3771 config: &str,
3772 vb: ShardedVarBuilder,
3773 normal_loading_metadata: NormalLoadingMetadata,
3774 attention_mechanism: AttentionImplementation,
3775 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3776 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3777 Ok(Box::new(models::glm4_moe::Glm4Moe::new(
3778 &cfg,
3779 vb,
3780 self.is_gptx(config)?,
3781 normal_loading_metadata,
3782 attention_mechanism,
3783 )?))
3784 }
3785 fn load_xlora(
3786 &self,
3787 _config: &str,
3788 _vb: ShardedVarBuilder,
3789 _lora_config: &[((String, String), LoraConfig)],
3790 _xlora_config: Option<XLoraConfig>,
3791 _xlora_ordering: Ordering,
3792 _normal_loading_metadata: NormalLoadingMetadata,
3793 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3794 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3795 todo!()
3796 }
3797 fn is_gptx(&self, _: &str) -> Result<bool> {
3798 Ok(true)
3799 }
3800 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3801 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3802 Ok(Box::new(cfg))
3803 }
3804}
3805
3806impl IsqModelLoader for GLM4MoeLoader {
3807 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3808 let mut data = vec![
3809 Regex::new(r"lm_head\.(weight|bias)$")?,
3810 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3812 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3813 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3814 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3815 ];
3816 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3817 for layer_idx in 0..cfg.num_hidden_layers {
3818 if layer_idx >= cfg.first_k_dense_replace {
3819 for i in 0..cfg.n_routed_experts {
3821 data.extend(vec![
3822 Regex::new(&format!(
3823 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3824 ))?,
3825 Regex::new(&format!(
3826 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3827 ))?,
3828 Regex::new(&format!(
3829 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3830 ))?,
3831 ]);
3832 }
3833 if cfg.n_shared_experts > 0 {
3834 data.extend(vec![
3835 Regex::new(&format!(
3836 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3837 ))?,
3838 Regex::new(&format!(
3839 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3840 ))?,
3841 Regex::new(&format!(
3842 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3843 ))?,
3844 ]);
3845 }
3846 } else {
3847 data.extend(vec![
3849 Regex::new(&format!(
3850 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3851 ))?,
3852 Regex::new(&format!(
3853 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3854 ))?,
3855 Regex::new(&format!(
3856 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3857 ))?,
3858 ]);
3859 };
3860 }
3861 Ok(data)
3862 }
3863 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3864 self.isq_layer_regexes(config)
3865 }
3866
3867 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3868 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3869 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3870 for layer_idx in 0..cfg.num_hidden_layers {
3871 if layer_idx >= cfg.first_k_dense_replace {
3872 for i in 0..cfg.n_routed_experts {
3874 data.extend(vec![
3875 Regex::new(&format!(
3876 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3877 ))?,
3878 Regex::new(&format!(
3879 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3880 ))?,
3881 Regex::new(&format!(
3882 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3883 ))?,
3884 ]);
3885 }
3886 if cfg.n_shared_experts > 0 {
3887 data.extend(vec![
3888 Regex::new(&format!(
3889 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3890 ))?,
3891 Regex::new(&format!(
3892 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3893 ))?,
3894 Regex::new(&format!(
3895 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3896 ))?,
3897 ]);
3898 }
3899 } else {
3900 data.extend(vec![
3902 Regex::new(&format!(
3903 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3904 ))?,
3905 Regex::new(&format!(
3906 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3907 ))?,
3908 Regex::new(&format!(
3909 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3910 ))?,
3911 ]);
3912 };
3913 }
3914 Ok(data)
3915 }
3916 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3917 self.isq_layer_regexes_moqe(config)
3918 }
3919}
3920
3921impl DeviceMappedModelLoader for GLM4MoeLoader {
3922 fn mapped_max_act_size_elems(
3923 &self,
3924 config: &str,
3925 params: &AutoDeviceMapParams,
3926 ) -> Result<usize> {
3927 let AutoDeviceMapParams::Text {
3928 max_seq_len,
3929 max_batch_size,
3930 } = params
3931 else {
3932 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3933 };
3934
3935 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3936
3937 Ok(
3938 max_batch_size
3939 * cfg.num_attention_heads
3940 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3941 )
3942 }
3943 fn non_mapped_max_act_size_elems(
3944 &self,
3945 _config: &str,
3946 _params: &AutoDeviceMapParams,
3947 ) -> Result<usize> {
3948 Ok(0)
3949 }
3950
3951 fn non_mapped_size_in_bytes(
3952 &self,
3953 config: &str,
3954 dtype: DType,
3955 weight_pack_factor: usize,
3956 _matformer_config: Option<&MatformerSliceConfig>,
3957 ) -> Result<usize> {
3958 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3959 let elems = {
3960 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3961 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3962 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3963 } else {
3964 0
3965 };
3966 let norm = cfg.hidden_size;
3967 embed_tokens + lm_head + norm
3968 };
3969 Ok(elems * dtype.size_in_bytes())
3970 }
3971
3972 fn layer_sizes_in_bytes(
3973 &self,
3974 config: &str,
3975 dtype: DType,
3976 weight_pack_factor: usize,
3977 _matformer_config: Option<&MatformerSliceConfig>,
3978 ) -> Result<Vec<usize>> {
3979 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3980 let mut per_layer_elems = Vec::new();
3981
3982 let head_dim = cfg.head_dim();
3983 for layer_idx in 0..cfg.num_hidden_layers {
3984 let input_layernorm = cfg.hidden_size;
3985 let post_attention_layernorm = cfg.hidden_size;
3986
3987 let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor
3989 + bias_if!(cfg.attention_bias, cfg.num_attention_heads * head_dim);
3990 let k_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
3991 + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
3992 let v_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
3993 + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
3994 let o_proj = cfg.num_attention_heads * head_dim * cfg.hidden_size / weight_pack_factor;
3995
3996 let qk_norm = if cfg.use_qk_norm {
3998 head_dim * 2
3999 } else {
4000 0
4001 };
4002
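// Layers past the first `first_k_dense_replace` dense layers use routed
// experts plus optional shared experts, a router gate, and a score-correction bias.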
4003 let moe_block = {
4004 let mut sum = 0;
4005 if layer_idx >= cfg.first_k_dense_replace {
4006 let h_size = cfg.hidden_size;
4008 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4009 * cfg.n_routed_experts;
4010 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4011 * cfg.n_routed_experts;
4012 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
4013 * cfg.n_routed_experts;
4014 let shared_experts = if cfg.n_shared_experts > 0 {
4015 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4016 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4017 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
4018 gate_proj + up_proj + down_proj
4019 } else {
4020 0
4021 };
4022 let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
4023 let e_score_correction_bias = cfg.n_routed_experts;
4024 sum += gate_proj
4025 + up_proj
4026 + down_proj
4027 + shared_experts
4028 + gate_weight
4029 + e_score_correction_bias;
4030 } else {
4031 let h_size = cfg.hidden_size;
4033 let i_size = cfg.intermediate_size;
4034 let gate_proj = h_size * i_size / weight_pack_factor;
4035 let up_proj = h_size * i_size / weight_pack_factor;
4036 let down_proj = i_size * h_size / weight_pack_factor;
4037 sum += gate_proj + up_proj + down_proj;
4038 }
4039 sum
4040 };
4041
4042 per_layer_elems.push(
4043 input_layernorm
4044 + post_attention_layernorm
4045 + q_proj
4046 + k_proj
4047 + v_proj
4048 + o_proj
4049 + qk_norm
4050 + moe_block,
4051 );
4052 }
4053
4054 Ok(per_layer_elems
4055 .into_iter()
4056 .map(|x| x * dtype.size_in_bytes())
4057 .collect())
4058 }
4059
4060 fn num_layers(&self, config: &str) -> Result<usize> {
4061 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4062 Ok(cfg.num_hidden_layers)
4063 }
4064
4065 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4066 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4067
4068 let head_dim = cfg.head_dim();
4069 let cfg = ModelConfigMetadata {
4070 max_seq_len: cfg.max_position_embeddings,
4071 num_layers: cfg.num_hidden_layers,
4072 hidden_size: cfg.hidden_size,
4073 num_kv_heads: cfg.num_key_value_heads,
4074 num_attn_heads: cfg.num_attention_heads,
4075 sliding_window: None,
4076 k_head_dim: head_dim,
4077 v_head_dim: head_dim,
4078 };
4079
4080 Ok(Box::new(cfg))
4081 }
4082}
4083
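/// Loader for Qwen3 MoE text models (`models::qwen3_moe`).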
4084pub struct Qwen3MoELoader;
4088
4089impl NormalModelLoader for Qwen3MoELoader {
4090 fn load(
4091 &self,
4092 config: &str,
4093 vb: ShardedVarBuilder,
4094 normal_loading_metadata: NormalLoadingMetadata,
4095 attention_mechanism: AttentionImplementation,
4096 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4097 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4098
4099 Ok(Box::new(models::qwen3_moe::Model::new(
4100 &cfg,
4101 vb,
4102 self.is_gptx(config)?,
4103 normal_loading_metadata,
4104 attention_mechanism,
4105 )?))
4106 }
4107 fn load_xlora(
4108 &self,
4109 _config: &str,
4110 _vb: ShardedVarBuilder,
4111 _lora_config: &[((String, String), LoraConfig)],
4112 _xlora_config: Option<XLoraConfig>,
4113 _xlora_ordering: Ordering,
4114 _normal_loading_metadata: NormalLoadingMetadata,
4115 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4116 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4117 todo!()
4118 }
4119 fn is_gptx(&self, _: &str) -> Result<bool> {
4120 Ok(true)
4121 }
4122 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4123 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4124
4125 Ok(Box::new(cfg))
4126 }
4127}
4128
4129impl IsqModelLoader for Qwen3MoELoader {
4130 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4131 Ok(vec![
4132 Regex::new(r"lm_head\.(weight|bias)$")?,
4133 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4135 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4136 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4137 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4138 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4140 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4141 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4142 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
4144 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
4145 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
4146 ])
4147 }
4148 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4149 self.isq_layer_regexes(config)
4150 }
4151 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
4152 self.isq_layer_regexes_moqe(config)
4153 }
4154}
4155
4156impl DeviceMappedModelLoader for Qwen3MoELoader {
4157 fn mapped_max_act_size_elems(
4158 &self,
4159 config: &str,
4160 params: &AutoDeviceMapParams,
4161 ) -> Result<usize> {
4162 let AutoDeviceMapParams::Text {
4163 max_seq_len,
4164 max_batch_size,
4165 } = params
4166 else {
4167 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4168 };
4169
4170 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4171
4172 Ok(
4173 max_batch_size
4174 * cfg.num_attention_heads
4175 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4176 )
4177 }
4178 fn non_mapped_max_act_size_elems(
4179 &self,
4180 _config: &str,
4181 _params: &AutoDeviceMapParams,
4182 ) -> Result<usize> {
4183 Ok(0)
4184 }
4185
4186 fn non_mapped_size_in_bytes(
4187 &self,
4188 config: &str,
4189 dtype: DType,
4190 weight_pack_factor: usize,
4191 _matformer_config: Option<&MatformerSliceConfig>,
4192 ) -> Result<usize> {
4193 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4194 let elems = {
4195 let embed_tokens = cfg.hidden_size * cfg.vocab_size;
4196 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4198 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4199 } else {
4200 0
4201 };
4202 let norm = cfg.hidden_size;
4203 embed_tokens + lm_head + norm
4204 };
4205 Ok(elems * dtype.size_in_bytes())
4206 }
4207
4208 fn layer_sizes_in_bytes(
4209 &self,
4210 config: &str,
4211 dtype: DType,
4212 weight_pack_factor: usize,
4213 _matformer_config: Option<&MatformerSliceConfig>,
4214 ) -> Result<Vec<usize>> {
4215 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4216
4217 let mut layer_sizes_in_bytes = Vec::new();
4218 for layer_idx in 0..cfg.num_hidden_layers {
4219 let input_layernorm = cfg.hidden_size;
4220 let post_attention_layernorm = cfg.hidden_size;
4221
4222 let size_in = cfg.hidden_size;
4223 let size_q = cfg.head_dim() * cfg.num_attention_heads;
4224 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
4225 let q_proj = size_in * size_q / weight_pack_factor;
4226 let k_proj = size_in * size_kv / weight_pack_factor;
4227 let v_proj = size_in * size_kv / weight_pack_factor;
4228 let o_proj = size_q * size_in / weight_pack_factor;
4229
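// A layer uses routed experts (plus the router gate) unless it is listed in
// `mlp_only_layers` or falls off the `decoder_sparse_step` schedule, in which
// case it uses a dense MLP of `intermediate_size`.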
4230 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
4231 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
4232 {
4233 let gate_size = cfg.hidden_size * cfg.num_experts;
4234 let expert_size = {
4235 let h_size = cfg.hidden_size;
4236 let i_size = cfg.moe_intermediate_size;
4237 let gate_proj = h_size * i_size / weight_pack_factor;
4238 let up_proj = h_size * i_size / weight_pack_factor;
4239 let down_proj = i_size * h_size / weight_pack_factor;
4240 gate_proj + up_proj + down_proj
4241 };
4242 expert_size * cfg.num_experts + gate_size
4243 } else {
4244 let h_size = cfg.hidden_size;
4245 let i_size = cfg.intermediate_size;
4246 let gate_proj = h_size * i_size / weight_pack_factor;
4247 let up_proj = h_size * i_size / weight_pack_factor;
4248 let down_proj = i_size * h_size / weight_pack_factor;
4249 gate_proj + up_proj + down_proj
4250 };
4251
4252 let q_norm = cfg.head_dim();
4253 let k_norm = cfg.head_dim();
4254
4255 let size_elems = input_layernorm
4256 + post_attention_layernorm
4257 + q_proj
4258 + k_proj
4259 + v_proj
4260 + o_proj
4261 + mlp_size
4262 + q_norm
4263 + k_norm;
4264
4265 let size_in_bytes = size_elems * dtype.size_in_bytes();
4266 layer_sizes_in_bytes.push(size_in_bytes);
4267 }
4268
4269 Ok(layer_sizes_in_bytes)
4270 }
4271
4272 fn num_layers(&self, config: &str) -> Result<usize> {
4273 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4274 Ok(cfg.num_hidden_layers)
4275 }
4276
4277 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4278 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4279
4280 let cfg = ModelConfigMetadata {
4281 max_seq_len: cfg.max_position_embeddings,
4282 num_layers: cfg.num_hidden_layers,
4283 hidden_size: cfg.hidden_size,
4284 num_kv_heads: cfg.num_key_value_heads,
4285 num_attn_heads: cfg.num_attention_heads,
4286 sliding_window: cfg.sliding_window,
4287 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4288 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4289 };
4290
4291 Ok(Box::new(cfg))
4292 }
4293}
4294
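/// Loader for SmolLM3 text models (`models::smollm3`).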
4295pub struct SmolLm3Loader;
4301
4302impl NormalModelLoader for SmolLm3Loader {
4303 fn load(
4304 &self,
4305 config: &str,
4306 vb: ShardedVarBuilder,
4307 normal_loading_metadata: NormalLoadingMetadata,
4308 attention_mechanism: AttentionImplementation,
4309 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4310 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4311
4312 Ok(Box::new(models::smollm3::SmolLm3::new(
4313 &cfg,
4314 vb,
4315 self.is_gptx(config)?,
4316 normal_loading_metadata,
4317 attention_mechanism,
4318 )?))
4319 }
4320 fn load_xlora(
4321 &self,
4322 _config: &str,
4323 _vb: ShardedVarBuilder,
4324 _lora_config: &[((String, String), LoraConfig)],
4325 _xlora_config: Option<XLoraConfig>,
4326 _xlora_ordering: Ordering,
4327 _normal_loading_metadata: NormalLoadingMetadata,
4328 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4329 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4330 todo!()
4331 }
4332 fn is_gptx(&self, _: &str) -> Result<bool> {
4333 Ok(true)
4334 }
4335 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4336 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4337 Ok(Box::new(cfg))
4338 }
4339}
4340
4341impl IsqModelLoader for SmolLm3Loader {
4342 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4343 Ok(vec![
4344 Regex::new(r"lm_head\.(weight|bias)$")?,
4345 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4347 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4348 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4349 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4350 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4352 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4353 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4354 ])
4355 }
4356 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4357 self.isq_layer_regexes(config)
4358 }
4359}
4360
4361impl DeviceMappedModelLoader for SmolLm3Loader {
4362 fn mapped_max_act_size_elems(
4363 &self,
4364 config: &str,
4365 params: &AutoDeviceMapParams,
4366 ) -> Result<usize> {
4367 let AutoDeviceMapParams::Text {
4368 max_seq_len,
4369 max_batch_size,
4370 } = params
4371 else {
4372 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4373 };
4374
4375 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4376
4377 Ok(
4378 max_batch_size
4379 * cfg.num_attention_heads
4380 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4381 )
4382 }
4383 fn non_mapped_max_act_size_elems(
4384 &self,
4385 _config: &str,
4386 _params: &AutoDeviceMapParams,
4387 ) -> Result<usize> {
4388 Ok(0)
4389 }
4390
4391 fn non_mapped_size_in_bytes(
4392 &self,
4393 config: &str,
4394 dtype: DType,
4395 weight_pack_factor: usize,
4396 _matformer_config: Option<&MatformerSliceConfig>,
4397 ) -> Result<usize> {
4398 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4399
4400 let elems = {
4401 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4402 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4404 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4405 } else {
4406 0
4407 };
4408 let norm = cfg.hidden_size;
4409 embed_tokens + lm_head + norm
4410 };
4411 Ok(elems * dtype.size_in_bytes())
4412 }
4413
4414 fn layer_sizes_in_bytes(
4415 &self,
4416 config: &str,
4417 dtype: DType,
4418 weight_pack_factor: usize,
4419 _matformer_config: Option<&MatformerSliceConfig>,
4420 ) -> Result<Vec<usize>> {
4421 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4422
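// Every decoder layer has the same shape, so one per-layer element count is
// computed and replicated `num_hidden_layers` times below.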
4423 let per_layer_elems = {
4424 let input_layernorm = cfg.hidden_size;
4425 let post_attention_layernorm = cfg.hidden_size;
4426
4427 let size_in = cfg.hidden_size;
4428 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4429 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4430 let q_proj = size_in * size_q / weight_pack_factor;
4431 let k_proj = size_in * size_kv / weight_pack_factor;
4432 let v_proj = size_in * size_kv / weight_pack_factor;
4433 let o_proj = size_q * size_in / weight_pack_factor;
4434
4435 let h_size = cfg.hidden_size;
4436 let i_size = cfg.intermediate_size;
4437 let gate_proj = h_size * i_size / weight_pack_factor;
4438 let up_proj = h_size * i_size / weight_pack_factor;
4439 let down_proj = i_size * h_size / weight_pack_factor;
4440
4441 input_layernorm
4442 + post_attention_layernorm
4443 + q_proj
4444 + k_proj
4445 + v_proj
4446 + o_proj
4447 + gate_proj
4448 + up_proj
4449 + down_proj
4450 };
4451 Ok(vec![
4452 per_layer_elems * dtype.size_in_bytes();
4453 cfg.num_hidden_layers
4454 ])
4455 }
4456
4457 fn num_layers(&self, config: &str) -> Result<usize> {
4458 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4459
4460 Ok(cfg.num_hidden_layers)
4461 }
4462 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4463 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4464
4465 let cfg = ModelConfigMetadata {
4466 max_seq_len: cfg.max_position_embeddings,
4467 num_layers: cfg.num_hidden_layers,
4468 hidden_size: cfg.hidden_size,
4469 num_kv_heads: cfg.num_key_value_heads,
4470 num_attn_heads: cfg.num_attention_heads,
4471 sliding_window: None,
4472 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4473 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4474 };
4475
4476 Ok(Box::new(cfg))
4477 }
4478}
4479
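/// Loader for Granite MoE Hybrid text models (`models::granite`).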
4480pub struct GraniteMoeHybridLoader;
4486
4487impl NormalModelLoader for GraniteMoeHybridLoader {
4488 fn load(
4489 &self,
4490 config: &str,
4491 vb: ShardedVarBuilder,
4492 normal_loading_metadata: NormalLoadingMetadata,
4493 attention_mechanism: AttentionImplementation,
4494 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4495 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4496
4497 Ok(Box::new(models::granite::GraniteMoeHybrid::new(
4498 &cfg,
4499 vb,
4500 self.is_gptx(config)?,
4501 normal_loading_metadata,
4502 attention_mechanism,
4503 )?))
4504 }
4505 fn load_xlora(
4506 &self,
4507 _config: &str,
4508 _vb: ShardedVarBuilder,
4509 _lora_config: &[((String, String), LoraConfig)],
4510 _xlora_config: Option<XLoraConfig>,
4511 _xlora_ordering: Ordering,
4512 _normal_loading_metadata: NormalLoadingMetadata,
4513 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4514 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4515 todo!()
4516 }
4517 fn is_gptx(&self, _: &str) -> Result<bool> {
4518 Ok(true)
4519 }
4520 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4521 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4522 Ok(Box::new(cfg))
4523 }
4524}
4525
4526impl IsqModelLoader for GraniteMoeHybridLoader {
4527 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4528 Ok(vec![
4529 Regex::new(r"lm_head\.(weight|bias)$")?,
4530 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4532 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4533 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4534 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4535 Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
4537 Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
4538 ])
4539 }
4540 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4541 self.isq_layer_regexes(config)
4542 }
4543}
4544
4545impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
4546 fn mapped_max_act_size_elems(
4547 &self,
4548 config: &str,
4549 params: &AutoDeviceMapParams,
4550 ) -> Result<usize> {
4551 let AutoDeviceMapParams::Text {
4552 max_seq_len,
4553 max_batch_size,
4554 } = params
4555 else {
4556 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4557 };
4558
4559 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4560
4561 Ok(
4562 max_batch_size
4563 * cfg.num_attention_heads
4564 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4565 )
4566 }
4567 fn non_mapped_max_act_size_elems(
4568 &self,
4569 _config: &str,
4570 _params: &AutoDeviceMapParams,
4571 ) -> Result<usize> {
4572 Ok(0)
4573 }
4574
4575 fn non_mapped_size_in_bytes(
4576 &self,
4577 config: &str,
4578 dtype: DType,
4579 weight_pack_factor: usize,
4580 _matformer_config: Option<&MatformerSliceConfig>,
4581 ) -> Result<usize> {
4582 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4583
4584 let elems = {
4585 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4586 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4588 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4589 } else {
4590 0
4591 };
4592 let norm = cfg.hidden_size;
4593 embed_tokens + lm_head + norm
4594 };
4595 Ok(elems * dtype.size_in_bytes())
4596 }
4597
4598 fn layer_sizes_in_bytes(
4599 &self,
4600 config: &str,
4601 dtype: DType,
4602 weight_pack_factor: usize,
4603 _matformer_config: Option<&MatformerSliceConfig>,
4604 ) -> Result<Vec<usize>> {
4605 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4606
4607 let per_layer_elems = {
4608 let input_layernorm = cfg.hidden_size;
4609 let post_attention_layernorm = cfg.hidden_size;
4610
4611 let size_in = cfg.hidden_size;
4612 let size_q = cfg.head_dim() * cfg.num_attention_heads;
4613 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
4614 let q_proj = size_in * size_q / weight_pack_factor;
4615 let k_proj = size_in * size_kv / weight_pack_factor;
4616 let v_proj = size_in * size_kv / weight_pack_factor;
4617 let o_proj = size_q * size_in / weight_pack_factor;
4618
4619 let h_size = cfg.hidden_size;
4620 let shared_i_size = cfg.shared_intermediate_size();
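// `input_linear` packs the gate and up projections of the shared MLP
// (hence the factor of 2); `output_linear` is the down projection.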
4621 let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
4623 let output_linear = shared_i_size * h_size / weight_pack_factor;
4624
4625 input_layernorm
4626 + post_attention_layernorm
4627 + q_proj
4628 + k_proj
4629 + v_proj
4630 + o_proj
4631 + input_linear
4632 + output_linear
4633 };
4634 Ok(vec![
4635 per_layer_elems * dtype.size_in_bytes();
4636 cfg.num_hidden_layers
4637 ])
4638 }
4639
4640 fn num_layers(&self, config: &str) -> Result<usize> {
4641 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4642
4643 Ok(cfg.num_hidden_layers)
4644 }
4645 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4646 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4647
4648 let cfg = ModelConfigMetadata {
4649 max_seq_len: cfg.max_position_embeddings,
4650 num_layers: cfg.num_hidden_layers,
4651 hidden_size: cfg.hidden_size,
4652 num_kv_heads: cfg.num_key_value_heads(),
4653 num_attn_heads: cfg.num_attention_heads,
4654 sliding_window: None,
4655 k_head_dim: cfg.head_dim(),
4656 v_head_dim: cfg.head_dim(),
4657 };
4658
4659 Ok(Box::new(cfg))
4660 }
4661}
4662
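/// Loader for GPT-OSS text models (`models::gpt_oss`). X-LoRA and paged
/// attention are not supported by this loader.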
4663pub struct GptOssLoader;
4669
4670impl NormalModelLoader for GptOssLoader {
4671 fn load(
4672 &self,
4673 config: &str,
4674 vb: ShardedVarBuilder,
4675 normal_loading_metadata: NormalLoadingMetadata,
4676 attention_mechanism: AttentionImplementation,
4677 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4678 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4679
4680 Ok(Box::new(models::gpt_oss::Model::new(
4681 &cfg,
4682 vb,
4683 self.is_gptx(config)?,
4684 normal_loading_metadata,
4685 attention_mechanism,
4686 )?))
4687 }
4688 fn load_xlora(
4689 &self,
4690 _config: &str,
4691 _vb: ShardedVarBuilder,
4692 _lora_config: &[((String, String), LoraConfig)],
4693 _xlora_config: Option<XLoraConfig>,
4694 _xlora_ordering: Ordering,
4695 _normal_loading_metadata: NormalLoadingMetadata,
4696 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4697 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4698 anyhow::bail!("GPT-OSS does not support X-LoRA")
4699 }
4700 fn is_gptx(&self, _: &str) -> Result<bool> {
4701 Ok(true)
4702 }
4703 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4704 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4705 Ok(Box::new(cfg))
4706 }
4707 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4708 Ok(false)
4709 }
4710}
4711
4712impl IsqModelLoader for GptOssLoader {
4713 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4714 Ok(vec![
4716 Regex::new(r"lm_head\.(weight|bias)$")?,
4717 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4719 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4720 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4721 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4722 ])
4723 }
4724 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4725 self.isq_layer_regexes(config)
4726 }
4727}
4728
4729impl DeviceMappedModelLoader for GptOssLoader {
4730 fn mapped_max_act_size_elems(
4731 &self,
4732 config: &str,
4733 params: &AutoDeviceMapParams,
4734 ) -> Result<usize> {
4735 let AutoDeviceMapParams::Text {
4736 max_seq_len,
4737 max_batch_size,
4738 } = params
4739 else {
4740 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4741 };
4742
4743 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4744
4745 Ok(
4746 max_batch_size
4747 * cfg.num_attention_heads
4748 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4749 )
4750 }
4751 fn non_mapped_max_act_size_elems(
4752 &self,
4753 _config: &str,
4754 _params: &AutoDeviceMapParams,
4755 ) -> Result<usize> {
4756 Ok(0)
4757 }
4758
4759 fn non_mapped_size_in_bytes(
4760 &self,
4761 config: &str,
4762 dtype: DType,
4763 weight_pack_factor: usize,
4764 _matformer_config: Option<&MatformerSliceConfig>,
4765 ) -> Result<usize> {
4766 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4767
4768 let elems = {
4769 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4770 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4771 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4772 } else {
4773 0
4774 };
4775 let norm = cfg.hidden_size;
4776 embed_tokens + lm_head + norm
4777 };
4778 Ok(elems * dtype.size_in_bytes())
4779 }
4780
4781 fn layer_sizes_in_bytes(
4782 &self,
4783 config: &str,
4784 dtype: DType,
4785 weight_pack_factor: usize,
4786 _matformer_config: Option<&MatformerSliceConfig>,
4787 ) -> Result<Vec<usize>> {
4788 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4789
4790 let per_layer_elems = {
4791 let input_layernorm = cfg.hidden_size;
4792 let post_attention_layernorm = cfg.hidden_size;
4793
4794 let size_in = cfg.hidden_size;
4795 let head_dim = cfg.head_dim();
4796 let size_q = head_dim * cfg.num_attention_heads;
4797 let size_kv = head_dim * cfg.num_key_value_heads;
4798 let q_proj =
4799 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
4800 let k_proj =
4801 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4802 let v_proj =
4803 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4804 let o_proj =
4805 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
4806
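// Expert weights are stored in MXFP4: two 4-bit values per byte (pack factor
// of 2) plus one scale per 32-element block; per-expert biases and the router
// are counted at full width.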
4807 let mxfp4_pack = 2;
4812 let gate_up_proj_size =
4813 cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / mxfp4_pack;
4814 let down_proj_size =
4815 cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / mxfp4_pack;
4816 let gate_up_scales =
4818 cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / 32;
4819 let down_scales = cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / 32;
4820 let gate_up_bias = cfg.num_local_experts * cfg.intermediate_size * 2;
4822 let down_bias = cfg.num_local_experts * cfg.hidden_size;
4823 let router = cfg.hidden_size * cfg.num_local_experts;
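// One learned attention-sink logit per attention head.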
4825 let sinks = cfg.num_attention_heads;
4827
4828 input_layernorm
4829 + post_attention_layernorm
4830 + q_proj
4831 + k_proj
4832 + v_proj
4833 + o_proj
4834 + gate_up_proj_size
4835 + down_proj_size
4836 + gate_up_scales
4837 + down_scales
4838 + gate_up_bias
4839 + down_bias
4840 + router
4841 + sinks
4842 };
4843 Ok(vec![
4844 per_layer_elems * dtype.size_in_bytes();
4845 cfg.num_hidden_layers
4846 ])
4847 }
4848
4849 fn num_layers(&self, config: &str) -> Result<usize> {
4850 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4851
4852 Ok(cfg.num_hidden_layers)
4853 }
4854 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4855 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4856
4857 let head_dim = cfg.head_dim();
4858 let cfg = ModelConfigMetadata {
4859 max_seq_len: cfg.max_position_embeddings,
4860 num_layers: cfg.num_hidden_layers,
4861 hidden_size: cfg.hidden_size,
4862 num_kv_heads: cfg.num_key_value_heads,
4863 num_attn_heads: cfg.num_attention_heads,
4864 sliding_window: cfg.sliding_window,
4865 k_head_dim: head_dim,
4866 v_head_dim: head_dim,
4867 };
4868
4869 Ok(Box::new(cfg))
4870 }
4871}