1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 str::FromStr,
5 sync::Arc,
6};
7
8use crate::{
9 amoe::AnyMoeBaseModelMixin,
10 device_map::DeviceMapper,
11 lora::{LoraConfig, Ordering},
12 paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
13 pipeline::{
14 isq::IsqModelLoader,
15 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
16 EitherCache, IsqModel,
17 },
18 utils::varbuilder_utils::DeviceForLoadTensor,
19 xlora_models::NonGranularState,
20};
21use anyhow::Result;
22use candle_core::{DType, Device, Tensor};
23use mistralrs_quant::log::once_log_info;
24
25use indicatif::MultiProgress;
26use mistralrs_quant::ShardedVarBuilder;
27#[cfg(feature = "pyo3_macros")]
28use pyo3::pyclass;
29
30use regex::Regex;
31use serde::Deserialize;
32
33use crate::{
34 models,
35 xlora_models::{self, XLoraConfig},
36};
37
38use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
39
40pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
41 #[allow(clippy::too_many_arguments)]
42 fn forward(
43 &self,
44 input_ids: &Tensor,
45 seqlen_offsets: &[usize],
46 context_lens: Vec<(usize, usize)>,
47 position_ids: Vec<usize>,
48 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
49 flash_params: &FlashParams,
50 ) -> candle_core::Result<Tensor>;
51 #[allow(clippy::too_many_arguments)]
52 fn xlora_forward(
53 &self,
54 input_ids: &Tensor,
55 input_ids_full: &Tensor,
56 seqlen_offsets: &[usize],
57 seqlen_offsets_full: &[usize],
58 no_kv_cache: bool,
59 non_granular_state: &Option<NonGranularState>,
60 context_lens: Vec<(usize, usize)>,
61 position_ids: Vec<usize>,
62 flash_params: &FlashParams,
63 flash_params_full: &FlashParams,
64 ) -> candle_core::Result<Tensor>;
65 fn is_xlora(&self) -> bool;
66 fn device(&self) -> &Device;
67 fn cache(&self) -> &EitherCache;
68 fn cache_mut(&mut self) -> &mut EitherCache;
69 fn max_seq_len(&self) -> usize;
70 fn config(&self) -> &ModelConfigMetadata;
71}
72
73pub struct NormalLoadingMetadata {
75 pub mapper: Box<dyn DeviceMapper + Send + Sync>,
77 pub loading_isq: bool,
79 pub real_device: Device,
81 pub multi_progress: Arc<MultiProgress>,
83}
84
85pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
86 fn load(
87 &self,
88 config: &str,
89 vb: ShardedVarBuilder,
90 normal_loading_metadata: NormalLoadingMetadata,
91 attention_mechanism: AttentionImplementation,
92 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
93 #[allow(clippy::too_many_arguments)]
94 fn load_xlora(
95 &self,
96 config: &str,
97 vb: ShardedVarBuilder,
98 lora_config: &[((String, String), LoraConfig)],
99 xlora_config: Option<XLoraConfig>,
100 xlora_ordering: Ordering,
101 normal_loading_metadata: NormalLoadingMetadata,
102 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
103 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
104 fn is_gptx(&self, config: &str) -> Result<bool>;
105 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
106 Ok(true)
107 }
108 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
109 fn get_device_for_tensor(
110 &self,
111 config: &str,
112 _mapper: &dyn DeviceMapper,
113 loading_isq: bool,
114 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
115 if loading_isq {
116 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
117 } else {
118 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
119 let num_layers = self.model_config(config)?.num_layers();
120 let closure = move |name: String| {
121 if let Some(captures) = re.captures(&name) {
122 captures
123 .get(1)
124 .and_then(|m| m.as_str().parse::<usize>().ok())
125 .map(|l| l.min(num_layers))
126 .map(DeviceForLoadTensor::Idx)
127 .unwrap_or(DeviceForLoadTensor::Base)
128 } else {
129 DeviceForLoadTensor::Base
130 }
131 };
132
133 Ok(Arc::new(closure))
134 }
135 }
136}
137
138#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
139#[derive(Clone, Debug, Deserialize, PartialEq)]
140pub enum NormalLoaderType {
142 #[serde(rename = "mistral")]
143 Mistral,
144 #[serde(rename = "gemma")]
145 Gemma,
146 #[serde(rename = "mixtral")]
147 Mixtral,
148 #[serde(rename = "llama")]
149 Llama,
150 #[serde(rename = "phi2")]
151 Phi2,
152 #[serde(rename = "phi3")]
153 Phi3,
154 #[serde(rename = "qwen2")]
155 Qwen2,
156 #[serde(rename = "gemma2")]
157 Gemma2,
158 #[serde(rename = "starcoder2")]
159 Starcoder2,
160 #[serde(rename = "phi3.5moe")]
161 Phi3_5MoE,
162 #[serde(rename = "deepseekv2")]
163 DeepSeekV2,
164 #[serde(rename = "deepseekv3")]
165 DeepSeekV3,
166 #[serde(rename = "qwen3")]
167 Qwen3,
168 #[serde(rename = "glm4")]
169 GLM4,
170 #[serde(rename = "qwen3moe")]
171 Qwen3Moe,
172 #[serde(rename = "smollm3")]
173 SmolLm3,
174}
175
176impl NormalLoaderType {
178 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
179 match name {
180 "MistralForCausalLM" => Ok(Self::Mistral),
181 "MixtralForCausalLM" => Ok(Self::Mixtral),
182 "GemmaForCausalLM" => Ok(Self::Gemma),
183 "Gemma2ForCausalLM" => Ok(Self::Gemma2),
184 "PhiForCausalLM" => Ok(Self::Phi2),
185 "Phi3ForCausalLM" => Ok(Self::Phi3),
186 "LlamaForCausalLM" => Ok(Self::Llama),
187 "Qwen2ForCausalLM" => Ok(Self::Qwen2),
188 "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
189 "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
190 "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
191 "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
192 "Qwen3ForCausalLM" => Ok(Self::Qwen3),
193 "Glm4ForCausalLM" => Ok(Self::GLM4),
194 "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
195 "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
196 other => anyhow::bail!(
197 "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
198 ),
199 }
200 }
201}
202
203impl FromStr for NormalLoaderType {
204 type Err = String;
205 fn from_str(s: &str) -> Result<Self, Self::Err> {
206 match s {
207 "mistral" => Ok(Self::Mistral),
208 "gemma" => Ok(Self::Gemma),
209 "mixtral" => Ok(Self::Mixtral),
210 "llama" => Ok(Self::Llama),
211 "phi2" => Ok(Self::Phi2),
212 "phi3" => Ok(Self::Phi3),
213 "qwen2" => Ok(Self::Qwen2),
214 "gemma2" => Ok(Self::Gemma2),
215 "starcoder2" => Ok(Self::Starcoder2),
216 "phi3.5moe" => Ok(Self::Phi3_5MoE),
217 "deepseekv2" => Ok(Self::DeepSeekV2),
218 "deepseekv3" => Ok(Self::DeepSeekV3),
219 "qwen3" => Ok(Self::Qwen3),
220 "glm4" => Ok(Self::GLM4),
221 "qwen3moe" => Ok(Self::Qwen3Moe),
222 "smollm3" => Ok(Self::SmolLm3),
223 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`.")),
224 }
225 }
226}
227
228impl Display for NormalLoaderType {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 match self {
231 Self::Gemma => write!(f, "gemma"),
232 Self::Gemma2 => write!(f, "gemma2"),
233 Self::Llama => write!(f, "llama"),
234 Self::Mistral => write!(f, "mistral"),
235 Self::Mixtral => write!(f, "mixtral"),
236 Self::Phi2 => write!(f, "phi2"),
237 Self::Phi3 => write!(f, "phi3"),
238 Self::Phi3_5MoE => write!(f, "phi3.5moe"),
239 Self::Qwen2 => write!(f, "qwen2"),
240 Self::Starcoder2 => write!(f, "starcoder2"),
241 Self::DeepSeekV2 => write!(f, "deepseekv2"),
242 Self::DeepSeekV3 => write!(f, "deepseekv3"),
243 Self::Qwen3 => write!(f, "qwen3"),
244 Self::GLM4 => write!(f, "glm4"),
245 Self::Qwen3Moe => write!(f, "qwen3moe"),
246 Self::SmolLm3 => write!(f, "smollm3"),
247 }
248 }
249}
250
251macro_rules! bias_if {
252 ($cond:expr, $size:expr) => {
253 if $cond {
254 $size
255 } else {
256 0
257 }
258 };
259}
260
261pub struct AutoNormalLoader;
263
264#[derive(Deserialize)]
265struct AutoNormalLoaderConfig {
266 architectures: Vec<String>,
267}
268
269impl AutoNormalLoader {
270 fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
271 let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
272 if auto_cfg.architectures.len() != 1 {
273 anyhow::bail!("Expected to have one name for `architectures` config field.")
274 }
275
276 let name = &auto_cfg.architectures[0];
277
278 let tp = NormalLoaderType::from_causal_lm_name(name)?;
279
280 once_log_info(format!("Automatic loader type determined to be `{tp}`"));
281
282 match tp {
283 NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
284 NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
285 NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
286 NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
287 NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
288 NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
289 NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
290 NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
291 NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
292 NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
293 NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
294 NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
295 NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
296 NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
297 NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
298 NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
299 }
300 }
301}
302
303impl NormalModelLoader for AutoNormalLoader {
304 fn load(
305 &self,
306 config: &str,
307 vb: ShardedVarBuilder,
308 normal_loading_metadata: NormalLoadingMetadata,
309 attention_mechanism: AttentionImplementation,
310 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
311 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
312 }
313 fn load_xlora(
314 &self,
315 config: &str,
316 vb: ShardedVarBuilder,
317 lora_config: &[((String, String), LoraConfig)],
318 xlora_config: Option<XLoraConfig>,
319 xlora_ordering: Ordering,
320 normal_loading_metadata: NormalLoadingMetadata,
321 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
322 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
323 Self::get_loader(config)?.load_xlora(
324 config,
325 vb,
326 lora_config,
327 xlora_config,
328 xlora_ordering,
329 normal_loading_metadata,
330 preload_adapters,
331 )
332 }
333 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
334 Self::get_loader(config)?.get_config_repr(config)
335 }
336 fn supports_paged_attention(&self, config: &str) -> Result<bool> {
337 Self::get_loader(config)?.supports_paged_attention(config)
338 }
339 fn is_gptx(&self, config: &str) -> Result<bool> {
340 Self::get_loader(config)?.is_gptx(config)
341 }
342}
343
344impl IsqModelLoader for AutoNormalLoader {
345 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
346 Self::get_loader(config)?.immediate_isq_predicates(config)
347 }
348 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
349 Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
350 }
351 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
352 Self::get_loader(config)?.isq_layer_regexes(config)
353 }
354 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
355 Self::get_loader(config)?.isq_layer_regexes_moqe(config)
356 }
357}
358
359impl DeviceMappedModelLoader for AutoNormalLoader {
360 fn non_mapped_size_in_bytes(
361 &self,
362 config: &str,
363 dtype: DType,
364 weight_pack_factor: usize,
365 ) -> Result<usize> {
366 Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
367 }
368 fn num_layers(&self, config: &str) -> Result<usize> {
369 Self::get_loader(config)?.num_layers(config)
370 }
371 fn layer_sizes_in_bytes(
372 &self,
373 config: &str,
374 dtype: DType,
375 weight_pack_factor: usize,
376 ) -> Result<Vec<usize>> {
377 Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
378 }
379 fn mapped_max_act_size_elems(
380 &self,
381 config: &str,
382 params: &super::AutoDeviceMapParams,
383 prompt_chunksize: usize,
384 ) -> Result<usize> {
385 Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
386 }
387 fn non_mapped_max_act_size_elems(
388 &self,
389 _config: &str,
390 _params: &AutoDeviceMapParams,
391 ) -> Result<usize> {
392 Ok(0)
393 }
394 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
395 Self::get_loader(config)?.model_config(config)
396 }
397}
398
399pub struct MistralLoader;
402
403impl NormalModelLoader for MistralLoader {
404 fn load(
405 &self,
406 config: &str,
407 vb: ShardedVarBuilder,
408 normal_loading_metadata: NormalLoadingMetadata,
409 attention_mechanism: AttentionImplementation,
410 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
411 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
412 Ok(Box::new(models::mistral::Model::new(
413 &cfg,
414 vb,
415 self.is_gptx(config)?,
416 normal_loading_metadata,
417 attention_mechanism,
418 )?))
419 }
420 fn load_xlora(
421 &self,
422 config: &str,
423 vb: ShardedVarBuilder,
424 lora_config: &[((String, String), LoraConfig)],
425 xlora_config: Option<XLoraConfig>,
426 xlora_ordering: Ordering,
427 normal_loading_metadata: NormalLoadingMetadata,
428 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
429 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
430 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
431 Ok(Box::new(xlora_models::XLoraMistral::new(
432 &cfg,
433 vb,
434 lora_config,
435 xlora_config,
436 xlora_ordering,
437 self.is_gptx(config)?,
438 normal_loading_metadata,
439 preload_adapters,
440 )?))
441 }
442 fn is_gptx(&self, _: &str) -> Result<bool> {
443 Ok(true)
444 }
445 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
446 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
447 Ok(Box::new(cfg))
448 }
449}
450
451impl IsqModelLoader for MistralLoader {
452 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
453 Ok(vec![
454 Regex::new(r"lm_head\.(weight|bias)$")?,
455 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
457 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
458 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
459 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
460 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
462 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
463 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
464 ])
465 }
466 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
467 self.isq_layer_regexes(config)
468 }
469}
470
471impl DeviceMappedModelLoader for MistralLoader {
472 fn mapped_max_act_size_elems(
473 &self,
474 config: &str,
475 params: &AutoDeviceMapParams,
476 prompt_chunksize: usize,
477 ) -> Result<usize> {
478 let AutoDeviceMapParams::Text {
479 max_seq_len: _,
480 max_batch_size,
481 } = params
482 else {
483 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
484 };
485
486 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
487
488 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
489 }
490 fn non_mapped_max_act_size_elems(
491 &self,
492 _config: &str,
493 _params: &AutoDeviceMapParams,
494 ) -> Result<usize> {
495 Ok(0)
496 }
497
498 fn non_mapped_size_in_bytes(
499 &self,
500 config: &str,
501 dtype: DType,
502 weight_pack_factor: usize,
503 ) -> Result<usize> {
504 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
505
506 let elems = {
507 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
508 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
510 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
511 } else {
512 0
513 };
514 let norm = cfg.hidden_size;
515 embed_tokens + lm_head + norm
516 };
517 Ok(elems * dtype.size_in_bytes())
518 }
519
520 fn layer_sizes_in_bytes(
521 &self,
522 config: &str,
523 dtype: DType,
524 weight_pack_factor: usize,
525 ) -> Result<Vec<usize>> {
526 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
527
528 let per_layer_elems = {
529 let input_layernorm = cfg.hidden_size;
530 let post_attention_layernorm = cfg.hidden_size;
531
532 let size_in = cfg.hidden_size;
533 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
534 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
535 let q_proj = size_in * size_q / weight_pack_factor;
536 let k_proj = size_in * size_kv / weight_pack_factor;
537 let v_proj = size_in * size_kv / weight_pack_factor;
538 let o_proj = size_q * size_in / weight_pack_factor;
539
540 let h_size = cfg.hidden_size;
541 let i_size = cfg.intermediate_size;
542 let gate_proj = h_size * i_size / weight_pack_factor;
543 let up_proj = h_size * i_size / weight_pack_factor;
544 let down_proj = i_size * h_size / weight_pack_factor;
545
546 input_layernorm
547 + post_attention_layernorm
548 + q_proj
549 + k_proj
550 + v_proj
551 + o_proj
552 + gate_proj
553 + up_proj
554 + down_proj
555 };
556 Ok(vec![
557 per_layer_elems * dtype.size_in_bytes();
558 cfg.num_hidden_layers
559 ])
560 }
561
562 fn num_layers(&self, config: &str) -> Result<usize> {
563 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
564 Ok(cfg.num_hidden_layers)
565 }
566
567 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
568 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
569
570 let cfg = ModelConfigMetadata {
571 max_seq_len: cfg.max_position_embeddings,
572 num_layers: cfg.num_hidden_layers,
573 hidden_size: cfg.hidden_size,
574 num_kv_heads: cfg.num_key_value_heads,
575 num_attn_heads: cfg.num_attention_heads,
576 sliding_window: cfg.sliding_window,
577 k_head_dim: cfg.head_dim(),
578 v_head_dim: cfg.head_dim(),
579 };
580
581 Ok(Box::new(cfg))
582 }
583}
584
585pub struct GemmaLoader;
591
592impl NormalModelLoader for GemmaLoader {
593 fn load(
594 &self,
595 config: &str,
596 vb: ShardedVarBuilder,
597 normal_loading_metadata: NormalLoadingMetadata,
598 attention_mechanism: AttentionImplementation,
599 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
600 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
601
602 Ok(Box::new(models::gemma::Model::new(
603 &cfg,
604 vb,
605 self.is_gptx(config)?,
606 normal_loading_metadata,
607 attention_mechanism,
608 )?))
609 }
610 fn load_xlora(
611 &self,
612 config: &str,
613 vb: ShardedVarBuilder,
614 lora_config: &[((String, String), LoraConfig)],
615 xlora_config: Option<XLoraConfig>,
616 xlora_ordering: Ordering,
617 normal_loading_metadata: NormalLoadingMetadata,
618 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
619 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
620 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
621
622 Ok(Box::new(xlora_models::XLoraGemma::new(
623 &cfg,
624 vb,
625 lora_config,
626 xlora_config,
627 xlora_ordering,
628 self.is_gptx(config)?,
629 normal_loading_metadata,
630 preload_adapters,
631 )?))
632 }
633 fn is_gptx(&self, _: &str) -> Result<bool> {
634 Ok(true)
635 }
636 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
637 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
638 Ok(Box::new(cfg))
639 }
640}
641
642impl IsqModelLoader for GemmaLoader {
643 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
644 Ok(vec![
645 Regex::new(r"lm_head\.(weight|bias)$")?,
646 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
648 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
649 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
650 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
651 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
653 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
654 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
655 ])
656 }
657 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
658 self.isq_layer_regexes(config)
659 }
660}
661
662impl DeviceMappedModelLoader for GemmaLoader {
663 fn mapped_max_act_size_elems(
664 &self,
665 config: &str,
666 params: &AutoDeviceMapParams,
667 prompt_chunksize: usize,
668 ) -> Result<usize> {
669 let AutoDeviceMapParams::Text {
670 max_seq_len: _,
671 max_batch_size,
672 } = params
673 else {
674 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
675 };
676
677 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
678
679 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
680 }
681 fn non_mapped_max_act_size_elems(
682 &self,
683 _config: &str,
684 _params: &AutoDeviceMapParams,
685 ) -> Result<usize> {
686 Ok(0)
687 }
688
689 fn non_mapped_size_in_bytes(
690 &self,
691 config: &str,
692 dtype: DType,
693 weight_pack_factor: usize,
694 ) -> Result<usize> {
695 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
696
697 let elems = {
698 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
699 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
701 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
702 } else {
703 0
704 };
705 let norm = cfg.hidden_size;
706 embed_tokens + lm_head + norm
707 };
708 Ok(elems * dtype.size_in_bytes())
709 }
710
711 fn layer_sizes_in_bytes(
712 &self,
713 config: &str,
714 dtype: DType,
715 weight_pack_factor: usize,
716 ) -> Result<Vec<usize>> {
717 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
718
719 let per_layer_elems = {
720 let input_layernorm = cfg.hidden_size;
721 let post_attention_layernorm = cfg.hidden_size;
722
723 let size_in = cfg.hidden_size;
724 let size_q = cfg.head_dim * cfg.num_attention_heads;
725 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
726 let q_proj =
727 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
728 let k_proj =
729 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
730 let v_proj =
731 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
732 let o_proj =
733 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
734
735 let h_size = cfg.hidden_size;
736 let i_size = cfg.intermediate_size;
737 let gate_proj = h_size * i_size / weight_pack_factor;
738 let up_proj = h_size * i_size / weight_pack_factor;
739 let down_proj = i_size * h_size / weight_pack_factor;
740
741 input_layernorm
742 + post_attention_layernorm
743 + q_proj
744 + k_proj
745 + v_proj
746 + o_proj
747 + gate_proj
748 + up_proj
749 + down_proj
750 };
751 Ok(vec![
752 per_layer_elems * dtype.size_in_bytes();
753 cfg.num_hidden_layers
754 ])
755 }
756
757 fn num_layers(&self, config: &str) -> Result<usize> {
758 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
759 Ok(cfg.num_hidden_layers)
760 }
761
762 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
763 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
764
765 let cfg = ModelConfigMetadata {
766 max_seq_len: cfg.max_position_embeddings,
767 num_layers: cfg.num_hidden_layers,
768 hidden_size: cfg.hidden_size,
769 num_kv_heads: cfg.num_key_value_heads,
770 num_attn_heads: cfg.num_attention_heads,
771 sliding_window: None,
772 k_head_dim: cfg.head_dim,
773 v_head_dim: cfg.head_dim,
774 };
775
776 Ok(Box::new(cfg))
777 }
778}
779
780pub struct LlamaLoader;
786
787impl NormalModelLoader for LlamaLoader {
788 fn load(
789 &self,
790 config: &str,
791 vb: ShardedVarBuilder,
792 normal_loading_metadata: NormalLoadingMetadata,
793 attention_mechanism: AttentionImplementation,
794 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
795 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
796
797 Ok(Box::new(models::llama::Llama::new(
798 &cfg,
799 vb,
800 self.is_gptx(config)?,
801 normal_loading_metadata,
802 attention_mechanism,
803 )?))
804 }
805 fn load_xlora(
806 &self,
807 config: &str,
808 vb: ShardedVarBuilder,
809 lora_config: &[((String, String), LoraConfig)],
810 xlora_config: Option<XLoraConfig>,
811 xlora_ordering: Ordering,
812 normal_loading_metadata: NormalLoadingMetadata,
813 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
814 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
815 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
816
817 Ok(Box::new(xlora_models::XLoraLlama::new(
818 &cfg,
819 vb,
820 lora_config,
821 xlora_config,
822 xlora_ordering,
823 self.is_gptx(config)?,
824 normal_loading_metadata,
825 preload_adapters,
826 )?))
827 }
828 fn is_gptx(&self, _: &str) -> Result<bool> {
829 Ok(true)
830 }
831 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
832 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
833 Ok(Box::new(cfg))
834 }
835}
836
837impl IsqModelLoader for LlamaLoader {
838 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
839 Ok(vec![
840 Regex::new(r"lm_head\.(weight|bias)$")?,
841 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
843 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
844 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
845 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
846 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
848 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
849 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
850 ])
851 }
852 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
853 self.isq_layer_regexes(config)
854 }
855}
856
857impl DeviceMappedModelLoader for LlamaLoader {
858 fn mapped_max_act_size_elems(
859 &self,
860 config: &str,
861 params: &AutoDeviceMapParams,
862 prompt_chunksize: usize,
863 ) -> Result<usize> {
864 let AutoDeviceMapParams::Text {
865 max_seq_len: _,
866 max_batch_size,
867 } = params
868 else {
869 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
870 };
871
872 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
873
874 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
875 }
876 fn non_mapped_max_act_size_elems(
877 &self,
878 _config: &str,
879 _params: &AutoDeviceMapParams,
880 ) -> Result<usize> {
881 Ok(0)
882 }
883
884 fn non_mapped_size_in_bytes(
885 &self,
886 config: &str,
887 dtype: DType,
888 weight_pack_factor: usize,
889 ) -> Result<usize> {
890 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
891
892 let elems = {
893 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
894 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
896 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
897 } else {
898 0
899 };
900 let norm = cfg.hidden_size;
901 embed_tokens + lm_head + norm
902 };
903 Ok(elems * dtype.size_in_bytes())
904 }
905
906 fn layer_sizes_in_bytes(
907 &self,
908 config: &str,
909 dtype: DType,
910 weight_pack_factor: usize,
911 ) -> Result<Vec<usize>> {
912 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
913
914 let per_layer_elems = {
915 let input_layernorm = cfg.hidden_size;
916 let post_attention_layernorm = cfg.hidden_size;
917
918 let size_in = cfg.hidden_size;
919 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
920 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
921 let q_proj = size_in * size_q / weight_pack_factor;
922 let k_proj = size_in * size_kv / weight_pack_factor;
923 let v_proj = size_in * size_kv / weight_pack_factor;
924 let o_proj = size_q * size_in / weight_pack_factor;
925
926 let h_size = cfg.hidden_size;
927 let i_size = cfg.intermediate_size;
928 let gate_proj = h_size * i_size / weight_pack_factor;
929 let up_proj = h_size * i_size / weight_pack_factor;
930 let down_proj = i_size * h_size / weight_pack_factor;
931
932 input_layernorm
933 + post_attention_layernorm
934 + q_proj
935 + k_proj
936 + v_proj
937 + o_proj
938 + gate_proj
939 + up_proj
940 + down_proj
941 };
942 Ok(vec![
943 per_layer_elems * dtype.size_in_bytes();
944 cfg.num_hidden_layers
945 ])
946 }
947
948 fn num_layers(&self, config: &str) -> Result<usize> {
949 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
950
951 Ok(cfg.num_hidden_layers)
952 }
953 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
954 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
955
956 let cfg = ModelConfigMetadata {
957 max_seq_len: cfg.max_position_embeddings,
958 num_layers: cfg.num_hidden_layers,
959 hidden_size: cfg.hidden_size,
960 num_kv_heads: cfg.num_key_value_heads,
961 num_attn_heads: cfg.num_attention_heads,
962 sliding_window: None,
963 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
964 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
965 };
966
967 Ok(Box::new(cfg))
968 }
969}
970
971pub struct MixtralLoader;
974
975impl NormalModelLoader for MixtralLoader {
976 fn load(
977 &self,
978 config: &str,
979 vb: ShardedVarBuilder,
980 normal_loading_metadata: NormalLoadingMetadata,
981 attention_mechanism: AttentionImplementation,
982 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
983 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
984
985 Ok(Box::new(models::mixtral::Model::new(
986 &cfg,
987 vb,
988 self.is_gptx(config)?,
989 normal_loading_metadata,
990 attention_mechanism,
991 )?))
992 }
993 fn load_xlora(
994 &self,
995 config: &str,
996 vb: ShardedVarBuilder,
997 lora_config: &[((String, String), LoraConfig)],
998 xlora_config: Option<XLoraConfig>,
999 xlora_ordering: Ordering,
1000 normal_loading_metadata: NormalLoadingMetadata,
1001 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1002 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1003 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1004
1005 Ok(Box::new(xlora_models::XLoraMixtral::new(
1006 &cfg,
1007 vb,
1008 lora_config,
1009 xlora_config,
1010 xlora_ordering,
1011 self.is_gptx(config)?,
1012 normal_loading_metadata,
1013 preload_adapters,
1014 )?))
1015 }
1016 fn is_gptx(&self, _: &str) -> Result<bool> {
1017 Ok(true)
1018 }
1019 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1020 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1021
1022 Ok(Box::new(cfg))
1023 }
1024}
1025
1026impl IsqModelLoader for MixtralLoader {
1027 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1028 Ok(vec![
1029 Regex::new(r"lm_head\.(weight|bias)$")?,
1030 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1032 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1033 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1034 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1035 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1037 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1038 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1039 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1040 ])
1041 }
1042 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1043 self.isq_layer_regexes(config)
1044 }
1045}
1046
1047impl DeviceMappedModelLoader for MixtralLoader {
1048 fn mapped_max_act_size_elems(
1049 &self,
1050 config: &str,
1051 params: &AutoDeviceMapParams,
1052 prompt_chunksize: usize,
1053 ) -> Result<usize> {
1054 let AutoDeviceMapParams::Text {
1055 max_seq_len: _,
1056 max_batch_size,
1057 } = params
1058 else {
1059 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1060 };
1061
1062 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1063
1064 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1065 }
1066 fn non_mapped_max_act_size_elems(
1067 &self,
1068 _config: &str,
1069 _params: &AutoDeviceMapParams,
1070 ) -> Result<usize> {
1071 Ok(0)
1072 }
1073
1074 fn non_mapped_size_in_bytes(
1075 &self,
1076 config: &str,
1077 dtype: DType,
1078 weight_pack_factor: usize,
1079 ) -> Result<usize> {
1080 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1081
1082 let elems = {
1083 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1084 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1086 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1087 } else {
1088 0
1089 };
1090 let norm = cfg.hidden_size;
1091 embed_tokens + lm_head + norm
1092 };
1093 Ok(elems * dtype.size_in_bytes())
1094 }
1095
1096 fn layer_sizes_in_bytes(
1097 &self,
1098 config: &str,
1099 dtype: DType,
1100 weight_pack_factor: usize,
1101 ) -> Result<Vec<usize>> {
1102 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1103
1104 let per_layer_elems = {
1105 let input_layernorm = cfg.hidden_size;
1106 let post_attention_layernorm = cfg.hidden_size;
1107
1108 let size_in = cfg.hidden_size;
1109 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1110 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1111 let q_proj = size_in * size_q / weight_pack_factor;
1112 let k_proj = size_in * size_kv / weight_pack_factor;
1113 let v_proj = size_in * size_kv / weight_pack_factor;
1114 let o_proj = size_q * size_in / weight_pack_factor;
1115
1116 let moe_block = {
1117 let gate = cfg.hidden_size * cfg.num_local_experts;
1118 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1120 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1121 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1122 gate + cfg.num_local_experts * w1
1123 + cfg.num_local_experts * w2
1124 + cfg.num_local_experts * w3
1125 };
1126
1127 input_layernorm
1128 + post_attention_layernorm
1129 + q_proj
1130 + k_proj
1131 + v_proj
1132 + o_proj
1133 + moe_block
1134 };
1135 Ok(vec![
1136 per_layer_elems * dtype.size_in_bytes();
1137 cfg.num_hidden_layers
1138 ])
1139 }
1140
1141 fn num_layers(&self, config: &str) -> Result<usize> {
1142 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1143
1144 Ok(cfg.num_hidden_layers)
1145 }
1146
1147 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1148 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1149
1150 let cfg = ModelConfigMetadata {
1151 max_seq_len: cfg.max_position_embeddings,
1152 num_layers: cfg.num_hidden_layers,
1153 hidden_size: cfg.hidden_size,
1154 num_kv_heads: cfg.num_key_value_heads,
1155 num_attn_heads: cfg.num_attention_heads,
1156 sliding_window: cfg.sliding_window,
1157 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1158 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1159 };
1160
1161 Ok(Box::new(cfg))
1162 }
1163}
1164
1165pub struct Phi2Loader;
1171
1172impl NormalModelLoader for Phi2Loader {
1173 fn load(
1174 &self,
1175 config: &str,
1176 vb: ShardedVarBuilder,
1177 normal_loading_metadata: NormalLoadingMetadata,
1178 attention_mechanism: AttentionImplementation,
1179 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1180 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1181
1182 Ok(Box::new(models::phi2::Model::new(
1183 &cfg,
1184 vb,
1185 self.is_gptx(config)?,
1186 normal_loading_metadata,
1187 attention_mechanism,
1188 )?))
1189 }
1190 fn load_xlora(
1191 &self,
1192 config: &str,
1193 vb: ShardedVarBuilder,
1194 lora_config: &[((String, String), LoraConfig)],
1195 xlora_config: Option<XLoraConfig>,
1196 xlora_ordering: Ordering,
1197 normal_loading_metadata: NormalLoadingMetadata,
1198 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1199 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1200 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1201
1202 Ok(Box::new(xlora_models::XLoraPhi2::new(
1203 &cfg,
1204 vb,
1205 lora_config,
1206 xlora_config,
1207 xlora_ordering,
1208 self.is_gptx(config)?,
1209 normal_loading_metadata,
1210 preload_adapters,
1211 )?))
1212 }
1213 fn is_gptx(&self, _: &str) -> Result<bool> {
1214 Ok(true)
1215 }
1216 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1217 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1218
1219 Ok(Box::new(cfg))
1220 }
1221}
1222
1223impl IsqModelLoader for Phi2Loader {
1224 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1225 Ok(vec![
1226 Regex::new(r"lm_head\.(weight|bias)$")?,
1227 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1229 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1230 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1231 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1232 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1234 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1235 ])
1236 }
1237 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1238 self.isq_layer_regexes(config)
1239 }
1240}
1241
1242impl DeviceMappedModelLoader for Phi2Loader {
1243 fn mapped_max_act_size_elems(
1244 &self,
1245 config: &str,
1246 params: &AutoDeviceMapParams,
1247 prompt_chunksize: usize,
1248 ) -> Result<usize> {
1249 let AutoDeviceMapParams::Text {
1250 max_seq_len: _,
1251 max_batch_size,
1252 } = params
1253 else {
1254 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1255 };
1256
1257 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1258
1259 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1260 }
1261 fn non_mapped_max_act_size_elems(
1262 &self,
1263 _config: &str,
1264 _params: &AutoDeviceMapParams,
1265 ) -> Result<usize> {
1266 Ok(0)
1267 }
1268
1269 fn non_mapped_size_in_bytes(
1270 &self,
1271 config: &str,
1272 dtype: DType,
1273 weight_pack_factor: usize,
1274 ) -> Result<usize> {
1275 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1276
1277 let elems = {
1278 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1279 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1281 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1282 } else {
1283 0
1284 };
1285 let norm = cfg.hidden_size;
1286 embed_tokens + lm_head + norm
1287 };
1288 Ok(elems * dtype.size_in_bytes())
1289 }
1290
1291 fn layer_sizes_in_bytes(
1292 &self,
1293 config: &str,
1294 dtype: DType,
1295 weight_pack_factor: usize,
1296 ) -> Result<Vec<usize>> {
1297 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1298
1299 let per_layer_elems = {
1300 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1301
1302 let size_in = cfg.hidden_size;
1303 let size_q = cfg.head_dim() * cfg.num_attention_heads;
1304 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1305 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1306 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1307 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1308 let o_proj = size_q * size_in / weight_pack_factor + size_in;
1309 let (q_norm, k_norm) = if cfg.qk_layernorm {
1310 (cfg.head_dim(), cfg.head_dim())
1311 } else {
1312 (0, 0)
1313 };
1314
1315 let h_size = cfg.hidden_size;
1316 let i_size = cfg.intermediate_size;
1317 let fc1 = h_size * i_size / weight_pack_factor;
1318 let fc2 = h_size * i_size / weight_pack_factor;
1319
1320 input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1321 };
1322 Ok(vec![
1323 per_layer_elems * dtype.size_in_bytes();
1324 cfg.num_hidden_layers
1325 ])
1326 }
1327
1328 fn num_layers(&self, config: &str) -> Result<usize> {
1329 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1330
1331 Ok(cfg.num_hidden_layers)
1332 }
1333
1334 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1335 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1336
1337 let cfg = ModelConfigMetadata {
1338 max_seq_len: cfg.max_position_embeddings,
1339 num_layers: cfg.num_hidden_layers,
1340 hidden_size: cfg.hidden_size,
1341 num_kv_heads: cfg.num_key_value_heads(),
1342 num_attn_heads: cfg.num_attention_heads,
1343 sliding_window: None,
1344 k_head_dim: cfg.head_dim(),
1345 v_head_dim: cfg.head_dim(),
1346 };
1347
1348 Ok(Box::new(cfg))
1349 }
1350}
1351
1352pub struct Phi3Loader;
1358
1359impl NormalModelLoader for Phi3Loader {
1360 fn load(
1361 &self,
1362 config: &str,
1363 vb: ShardedVarBuilder,
1364 normal_loading_metadata: NormalLoadingMetadata,
1365 attention_mechanism: AttentionImplementation,
1366 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1367 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1368
1369 Ok(Box::new(models::phi3::Model::new(
1370 &cfg,
1371 vb,
1372 self.is_gptx(config)?,
1373 normal_loading_metadata,
1374 attention_mechanism,
1375 )?))
1376 }
1377 fn load_xlora(
1378 &self,
1379 config: &str,
1380 vb: ShardedVarBuilder,
1381 lora_config: &[((String, String), LoraConfig)],
1382 xlora_config: Option<XLoraConfig>,
1383 xlora_ordering: Ordering,
1384 normal_loading_metadata: NormalLoadingMetadata,
1385 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1386 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1387 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1388
1389 Ok(Box::new(xlora_models::XLoraPhi3::new(
1390 &cfg,
1391 vb,
1392 lora_config,
1393 xlora_config,
1394 xlora_ordering,
1395 self.is_gptx(config)?,
1396 normal_loading_metadata,
1397 preload_adapters,
1398 )?))
1399 }
1400 fn is_gptx(&self, _: &str) -> Result<bool> {
1401 Ok(true)
1402 }
1403 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1404 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1405
1406 Ok(Box::new(cfg))
1407 }
1408}
1409
1410impl IsqModelLoader for Phi3Loader {
1411 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1412 Ok(vec![
1413 Regex::new(r"lm_head\.(weight|bias)$")?,
1414 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1416 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1417 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1419 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1420 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1421 ])
1422 }
1423 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1424 self.isq_layer_regexes(config)
1425 }
1426}
1427
1428impl DeviceMappedModelLoader for Phi3Loader {
1429 fn mapped_max_act_size_elems(
1430 &self,
1431 config: &str,
1432 params: &AutoDeviceMapParams,
1433 prompt_chunksize: usize,
1434 ) -> Result<usize> {
1435 let AutoDeviceMapParams::Text {
1436 max_seq_len: _,
1437 max_batch_size,
1438 } = params
1439 else {
1440 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1441 };
1442
1443 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1444
1445 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1446 }
1447 fn non_mapped_max_act_size_elems(
1448 &self,
1449 _config: &str,
1450 _params: &AutoDeviceMapParams,
1451 ) -> Result<usize> {
1452 Ok(0)
1453 }
1454
1455 fn non_mapped_size_in_bytes(
1456 &self,
1457 config: &str,
1458 dtype: DType,
1459 weight_pack_factor: usize,
1460 ) -> Result<usize> {
1461 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1462
1463 let elems = {
1464 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1465 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1467 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1468 } else {
1469 0
1470 };
1471 let norm = cfg.hidden_size;
1472 embed_tokens + lm_head + norm
1473 };
1474 Ok(elems * dtype.size_in_bytes())
1475 }
1476
1477 fn layer_sizes_in_bytes(
1478 &self,
1479 config: &str,
1480 dtype: DType,
1481 weight_pack_factor: usize,
1482 ) -> Result<Vec<usize>> {
1483 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1484
1485 let per_layer_elems = {
1486 let input_layernorm = cfg.hidden_size;
1487 let post_attention_layernorm = cfg.hidden_size;
1488
1489 let size_in = cfg.hidden_size;
1490 let head_dim = cfg.head_dim();
1491 let op_size =
1492 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1493 let qkv_proj = size_in * op_size / weight_pack_factor;
1494 let o_proj =
1495 (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1496
1497 let h_size = cfg.hidden_size;
1498 let i_size = cfg.intermediate_size;
1499 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1500 let down_proj = h_size * i_size / weight_pack_factor;
1501
1502 input_layernorm
1503 + post_attention_layernorm
1504 + qkv_proj
1505 + o_proj
1506 + gate_up_proj
1507 + down_proj
1508 };
1509 Ok(vec![
1510 per_layer_elems * dtype.size_in_bytes();
1511 cfg.num_hidden_layers
1512 ])
1513 }
1514
1515 fn num_layers(&self, config: &str) -> Result<usize> {
1516 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1517
1518 Ok(cfg.num_hidden_layers)
1519 }
1520
1521 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1522 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1523
1524 let cfg = ModelConfigMetadata {
1525 max_seq_len: cfg.max_position_embeddings,
1526 num_layers: cfg.num_hidden_layers,
1527 hidden_size: cfg.hidden_size,
1528 num_kv_heads: cfg.num_key_value_heads,
1529 num_attn_heads: cfg.num_attention_heads,
1530 sliding_window: cfg.sliding_window,
1531 k_head_dim: cfg.head_dim(),
1532 v_head_dim: cfg.head_dim(),
1533 };
1534
1535 Ok(Box::new(cfg))
1536 }
1537}
1538
1539pub struct Qwen2Loader;
1545
1546impl NormalModelLoader for Qwen2Loader {
1547 fn load(
1548 &self,
1549 config: &str,
1550 vb: ShardedVarBuilder,
1551 normal_loading_metadata: NormalLoadingMetadata,
1552 attention_mechanism: AttentionImplementation,
1553 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1554 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1555
1556 Ok(Box::new(models::qwen2::Model::new(
1557 &cfg,
1558 vb,
1559 self.is_gptx(config)?,
1560 normal_loading_metadata,
1561 attention_mechanism,
1562 )?))
1563 }
1564 fn load_xlora(
1565 &self,
1566 _config: &str,
1567 _vb: ShardedVarBuilder,
1568 _lora_config: &[((String, String), LoraConfig)],
1569 _xlora_config: Option<XLoraConfig>,
1570 _xlora_ordering: Ordering,
1571 _normal_loading_metadata: NormalLoadingMetadata,
1572 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1573 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1574 todo!()
1575 }
1576 fn is_gptx(&self, _: &str) -> Result<bool> {
1577 Ok(true)
1578 }
1579 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1580 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1581
1582 Ok(Box::new(cfg))
1583 }
1584}
1585
1586impl IsqModelLoader for Qwen2Loader {
1587 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1588 Ok(vec![
1589 Regex::new(r"lm_head\.(weight|bias)$")?,
1590 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1592 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1593 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1594 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1595 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1597 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1598 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1599 ])
1600 }
1601 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1602 self.isq_layer_regexes(config)
1603 }
1604}
1605
1606impl DeviceMappedModelLoader for Qwen2Loader {
1607 fn mapped_max_act_size_elems(
1608 &self,
1609 config: &str,
1610 params: &AutoDeviceMapParams,
1611 prompt_chunksize: usize,
1612 ) -> Result<usize> {
1613 let AutoDeviceMapParams::Text {
1614 max_seq_len: _,
1615 max_batch_size,
1616 } = params
1617 else {
1618 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1619 };
1620
1621 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1622
1623 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1624 }
1625 fn non_mapped_max_act_size_elems(
1626 &self,
1627 _config: &str,
1628 _params: &AutoDeviceMapParams,
1629 ) -> Result<usize> {
1630 Ok(0)
1631 }
1632
1633 fn non_mapped_size_in_bytes(
1634 &self,
1635 config: &str,
1636 dtype: DType,
1637 weight_pack_factor: usize,
1638 ) -> Result<usize> {
1639 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1640
1641 let elems = {
1642 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1643 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1645 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1646 } else {
1647 0
1648 };
1649 let norm = cfg.hidden_size;
1650 embed_tokens + lm_head + norm
1651 };
1652 Ok(elems * dtype.size_in_bytes())
1653 }
1654
1655 fn layer_sizes_in_bytes(
1656 &self,
1657 config: &str,
1658 dtype: DType,
1659 weight_pack_factor: usize,
1660 ) -> Result<Vec<usize>> {
1661 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1662
1663 let per_layer_elems = {
1664 let input_layernorm = cfg.hidden_size;
1665 let post_attention_layernorm = cfg.hidden_size;
1666
1667 let size_in = cfg.hidden_size;
1668 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1669 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1670 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1671 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1672 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1673 let o_proj = size_q * size_in / weight_pack_factor;
1674
1675 let h_size = cfg.hidden_size;
1676 let i_size = cfg.intermediate_size;
1677 let gate_proj = h_size * i_size / weight_pack_factor;
1678 let up_proj = h_size * i_size / weight_pack_factor;
1679 let down_proj = i_size * h_size / weight_pack_factor;
1680
1681 input_layernorm
1682 + post_attention_layernorm
1683 + q_proj
1684 + k_proj
1685 + v_proj
1686 + o_proj
1687 + gate_proj
1688 + up_proj
1689 + down_proj
1690 };
1691 Ok(vec![
1692 per_layer_elems * dtype.size_in_bytes();
1693 cfg.num_hidden_layers
1694 ])
1695 }
1696
1697 fn num_layers(&self, config: &str) -> Result<usize> {
1698 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1699
1700 Ok(cfg.num_hidden_layers)
1701 }
1702
1703 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1704 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1705
1706 let cfg = ModelConfigMetadata {
1707 max_seq_len: cfg.max_position_embeddings,
1708 num_layers: cfg.num_hidden_layers,
1709 hidden_size: cfg.hidden_size,
1710 num_kv_heads: cfg.num_key_value_heads,
1711 num_attn_heads: cfg.num_attention_heads,
1712 sliding_window: cfg.sliding_window,
1713 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1714 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1715 };
1716
1717 Ok(Box::new(cfg))
1718 }
1719}
1720
1721pub struct Gemma2Loader;
1727
1728impl NormalModelLoader for Gemma2Loader {
1729 fn load(
1730 &self,
1731 config: &str,
1732 vb: ShardedVarBuilder,
1733 normal_loading_metadata: NormalLoadingMetadata,
1734 attention_mechanism: AttentionImplementation,
1735 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1736 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1737
1738 Ok(Box::new(models::gemma2::Model::new(
1739 &cfg,
1740 vb,
1741 self.is_gptx(config)?,
1742 normal_loading_metadata,
1743 attention_mechanism,
1744 )?))
1745 }
1746 fn load_xlora(
1747 &self,
1748 config: &str,
1749 vb: ShardedVarBuilder,
1750 lora_config: &[((String, String), LoraConfig)],
1751 xlora_config: Option<XLoraConfig>,
1752 xlora_ordering: Ordering,
1753 normal_loading_metadata: NormalLoadingMetadata,
1754 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1755 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1756 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1757
1758 Ok(Box::new(xlora_models::XLoraGemma2::new(
1759 &cfg,
1760 vb,
1761 lora_config,
1762 xlora_config,
1763 xlora_ordering,
1764 self.is_gptx(config)?,
1765 normal_loading_metadata,
1766 preload_adapters,
1767 )?))
1768 }
1769 fn is_gptx(&self, _: &str) -> Result<bool> {
1770 Ok(true)
1771 }
1772 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1773 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1774
1775 Ok(Box::new(cfg))
1776 }
1777}
1778
1779impl IsqModelLoader for Gemma2Loader {
1780 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1781 Ok(vec![
1782 Regex::new(r"lm_head\.(weight|bias)$")?,
1783 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1785 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1786 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1787 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1788 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1790 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1791 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1792 ])
1793 }
1794 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1795 self.isq_layer_regexes(config)
1796 }
1797}
1798
1799impl DeviceMappedModelLoader for Gemma2Loader {
1800 fn mapped_max_act_size_elems(
1801 &self,
1802 config: &str,
1803 params: &AutoDeviceMapParams,
1804 prompt_chunksize: usize,
1805 ) -> Result<usize> {
1806 let AutoDeviceMapParams::Text {
1807 max_seq_len: _,
1808 max_batch_size,
1809 } = params
1810 else {
1811 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1812 };
1813
1814 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1815
1816 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1817 }
1818 fn non_mapped_max_act_size_elems(
1819 &self,
1820 _config: &str,
1821 _params: &AutoDeviceMapParams,
1822 ) -> Result<usize> {
1823 Ok(0)
1824 }
1825
1826 fn non_mapped_size_in_bytes(
1827 &self,
1828 config: &str,
1829 dtype: DType,
1830 weight_pack_factor: usize,
1831 ) -> Result<usize> {
1832 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1833
1834 let elems = {
1835 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1836 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1838 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1839 } else {
1840 0
1841 };
1842 let norm = cfg.hidden_size;
1843 embed_tokens + lm_head + norm
1844 };
1845 Ok(elems * dtype.size_in_bytes())
1846 }
1847
1848 fn layer_sizes_in_bytes(
1849 &self,
1850 config: &str,
1851 dtype: DType,
1852 weight_pack_factor: usize,
1853 ) -> Result<Vec<usize>> {
1854 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1855
1856 let per_layer_elems = {
1857 let input_layernorm = cfg.hidden_size;
1858 let post_attention_layernorm = cfg.hidden_size;
1859
1860 let size_in = cfg.hidden_size;
1861 let size_q = cfg.head_dim * cfg.num_attention_heads;
1862 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1863 let q_proj =
1864 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1865 let k_proj =
1866 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1867 let v_proj =
1868 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1869 let o_proj =
1870 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1871
1872 let h_size = cfg.hidden_size;
1873 let i_size = cfg.intermediate_size;
1874 let gate_proj = h_size * i_size / weight_pack_factor;
1875 let up_proj = h_size * i_size / weight_pack_factor;
1876 let down_proj = i_size * h_size / weight_pack_factor;
1877
1878 input_layernorm
1879 + post_attention_layernorm
1880 + q_proj
1881 + k_proj
1882 + v_proj
1883 + o_proj
1884 + gate_proj
1885 + up_proj
1886 + down_proj
1887 };
1888 Ok(vec![
1889 per_layer_elems * dtype.size_in_bytes();
1890 cfg.num_hidden_layers
1891 ])
1892 }
1893
1894 fn num_layers(&self, config: &str) -> Result<usize> {
1895 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1896
1897 Ok(cfg.num_hidden_layers)
1898 }
1899 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1900 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1901
1902 let cfg = ModelConfigMetadata {
1903 max_seq_len: cfg.max_position_embeddings,
1904 num_layers: cfg.num_hidden_layers,
1905 hidden_size: cfg.hidden_size,
1906 num_kv_heads: cfg.num_key_value_heads,
1907 num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1910 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1911 };
1912
1913 Ok(Box::new(cfg))
1914 }
1915}
1916
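/// [`NormalModelLoader`] for a Starcoder2 model; parses `config` as
/// [`models::starcoder2::Config`].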
1917pub struct Starcoder2Loader;
1923
1924impl NormalModelLoader for Starcoder2Loader {
1925 fn load(
1926 &self,
1927 config: &str,
1928 vb: ShardedVarBuilder,
1929 normal_loading_metadata: NormalLoadingMetadata,
1930 attention_mechanism: AttentionImplementation,
1931 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1932 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1933
1934 Ok(Box::new(models::starcoder2::Model::new(
1935 &cfg,
1936 vb,
1937 self.is_gptx(config)?,
1938 normal_loading_metadata,
1939 attention_mechanism,
1940 )?))
1941 }
1942 fn load_xlora(
1943 &self,
1944 config: &str,
1945 vb: ShardedVarBuilder,
1946 lora_config: &[((String, String), LoraConfig)],
1947 xlora_config: Option<XLoraConfig>,
1948 xlora_ordering: Ordering,
1949 normal_loading_metadata: NormalLoadingMetadata,
1950 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1951 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1952 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1953
1954 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
1955 &cfg,
1956 vb,
1957 lora_config,
1958 xlora_config,
1959 xlora_ordering,
1960 self.is_gptx(config)?,
1961 normal_loading_metadata,
1962 preload_adapters,
1963 )?))
1964 }
1965 fn is_gptx(&self, _: &str) -> Result<bool> {
1966 Ok(true)
1967 }
1968 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1969 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1970
1971 Ok(Box::new(cfg))
1972 }
1973}
1974
1975impl IsqModelLoader for Starcoder2Loader {
1976 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1977 Ok(vec![
1978 Regex::new(r"lm_head\.(weight|bias)$")?,
1979 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1981 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1982 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1983 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1984 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1986 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
1987 ])
1988 }
1989 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1990 self.isq_layer_regexes(config)
1991 }
1992}
1993
1994impl DeviceMappedModelLoader for Starcoder2Loader {
1995 fn mapped_max_act_size_elems(
1996 &self,
1997 config: &str,
1998 params: &AutoDeviceMapParams,
1999 prompt_chunksize: usize,
2000 ) -> Result<usize> {
2001 let AutoDeviceMapParams::Text {
2002 max_seq_len: _,
2003 max_batch_size,
2004 } = params
2005 else {
2006 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2007 };
2008
2009 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2010
2011 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2012 }
2013 fn non_mapped_max_act_size_elems(
2014 &self,
2015 _config: &str,
2016 _params: &AutoDeviceMapParams,
2017 ) -> Result<usize> {
2018 Ok(0)
2019 }
2020
2021 fn non_mapped_size_in_bytes(
2022 &self,
2023 config: &str,
2024 dtype: DType,
2025 weight_pack_factor: usize,
2026 ) -> Result<usize> {
2027 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2028
2029 let elems = {
2030 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2031 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2033 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2034 } else {
2035 0
2036 };
2037 let norm = cfg.hidden_size + cfg.hidden_size;
2038 embed_tokens + lm_head + norm
2039 };
2040 Ok(elems * dtype.size_in_bytes())
2041 }
2042
2043 fn layer_sizes_in_bytes(
2044 &self,
2045 config: &str,
2046 dtype: DType,
2047 weight_pack_factor: usize,
2048 ) -> Result<Vec<usize>> {
2049 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2050
2051 let per_layer_elems = {
2052 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2053 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2054
2055 let size_in = cfg.hidden_size;
2056 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2057 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2058 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2059 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2060 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2061 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2062
2063 let h_size = cfg.hidden_size;
2064 let i_size = cfg.intermediate_size;
2065 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2066 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2067
2068 input_layernorm
2069 + post_attention_layernorm
2070 + q_proj
2071 + k_proj
2072 + v_proj
2073 + o_proj
2074 + fc1
2075 + fc2
2076 };
2077 Ok(vec![
2078 per_layer_elems * dtype.size_in_bytes();
2079 cfg.num_hidden_layers
2080 ])
2081 }
2082
2083 fn num_layers(&self, config: &str) -> Result<usize> {
2084 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2085
2086 Ok(cfg.num_hidden_layers)
2087 }
2088
2089 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2090 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2091
2092 let cfg = ModelConfigMetadata {
2093 max_seq_len: cfg.max_position_embeddings,
2094 num_layers: cfg.num_hidden_layers,
2095 hidden_size: cfg.hidden_size,
2096 num_kv_heads: cfg.num_key_value_heads,
2097 num_attn_heads: cfg.num_attention_heads,
2098 sliding_window: cfg.sliding_window,
2099 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2100 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2101 };
2102
2103 Ok(Box::new(cfg))
2104 }
2105}
2106
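/// [`NormalModelLoader`] for a Phi 3.5 MoE model; parses `config` as
/// [`models::phi3_5_moe::Config`].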
2107pub struct Phi3_5MoELoader;
2113
2114impl NormalModelLoader for Phi3_5MoELoader {
2115 fn load(
2116 &self,
2117 config: &str,
2118 vb: ShardedVarBuilder,
2119 normal_loading_metadata: NormalLoadingMetadata,
2120 attention_mechanism: AttentionImplementation,
2121 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2122 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2123
2124 Ok(Box::new(models::phi3_5_moe::Model::new(
2125 &cfg,
2126 vb,
2127 self.is_gptx(config)?,
2128 normal_loading_metadata,
2129 attention_mechanism,
2130 )?))
2131 }
2132 fn load_xlora(
2133 &self,
2134 config: &str,
2135 vb: ShardedVarBuilder,
2136 lora_config: &[((String, String), LoraConfig)],
2137 xlora_config: Option<XLoraConfig>,
2138 xlora_ordering: Ordering,
2139 normal_loading_metadata: NormalLoadingMetadata,
2140 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2141 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2142 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2143
2144 Ok(Box::new(xlora_models::XLoraPhi3::new(
2145 &cfg,
2146 vb,
2147 lora_config,
2148 xlora_config,
2149 xlora_ordering,
2150 self.is_gptx(config)?,
2151 normal_loading_metadata,
2152 preload_adapters,
2153 )?))
2154 }
2155 fn is_gptx(&self, _: &str) -> Result<bool> {
2156 Ok(true)
2157 }
2158 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2159 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2160
2161 Ok(Box::new(cfg))
2162 }
2163}
2164
2165impl IsqModelLoader for Phi3_5MoELoader {
2166 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2167 Ok(vec![
2168 Regex::new(r"lm_head\.(weight|bias)$")?,
2169 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2171 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2172 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2173 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2174 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2176 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2177 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2178 ])
2179 }
2180 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2181 self.isq_layer_regexes(config)
2182 }
2183
2184 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2185 Ok(vec![
2186 Regex::new(r"lm_head\.(weight|bias)$")?,
2187 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2189 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2190 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2191 ])
2192 }
2193 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2194 self.isq_layer_regexes_moqe(config)
2195 }
2196}
2197
2198impl DeviceMappedModelLoader for Phi3_5MoELoader {
2199 fn mapped_max_act_size_elems(
2200 &self,
2201 config: &str,
2202 params: &AutoDeviceMapParams,
2203 prompt_chunksize: usize,
2204 ) -> Result<usize> {
2205 let AutoDeviceMapParams::Text {
2206 max_seq_len: _,
2207 max_batch_size,
2208 } = params
2209 else {
2210 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2211 };
2212
2213 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2214
2215 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2216 }
2217 fn non_mapped_max_act_size_elems(
2218 &self,
2219 _config: &str,
2220 _params: &AutoDeviceMapParams,
2221 ) -> Result<usize> {
2222 Ok(0)
2223 }
2224
2225 fn non_mapped_size_in_bytes(
2226 &self,
2227 config: &str,
2228 dtype: DType,
2229 weight_pack_factor: usize,
2230 ) -> Result<usize> {
2231 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2232
2233 let elems = {
2234 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2235 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2237 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2238 } else {
2239 0
2240 };
2241 let norm = cfg.hidden_size;
2242 embed_tokens + lm_head + norm
2243 };
2244 Ok(elems * dtype.size_in_bytes())
2245 }
2246
2247 fn layer_sizes_in_bytes(
2248 &self,
2249 config: &str,
2250 dtype: DType,
2251 weight_pack_factor: usize,
2252 ) -> Result<Vec<usize>> {
2253 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2254
2255 let per_layer_elems = {
2256 let input_layernorm = cfg.hidden_size;
2257 let post_attention_layernorm = cfg.hidden_size;
2258
2259 let size_in = cfg.hidden_size;
2260 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2261 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2262 let q_proj =
2263 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2264 let k_proj =
2265 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2266 let v_proj =
2267 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2268 let o_proj =
2269 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2270
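                // Sparse MoE block: a router gate plus w1/w2/w3 weights for each of the
                // `num_local_experts` experts.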
2271 let moe_block = {
2272 let gate = cfg.hidden_size * cfg.num_local_experts;
2273 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2275 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2276 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2277 gate + cfg.num_local_experts * w1
2278 + cfg.num_local_experts * w2
2279 + cfg.num_local_experts * w3
2280 };
2281
2282 input_layernorm
2283 + post_attention_layernorm
2284 + q_proj
2285 + k_proj
2286 + v_proj
2287 + o_proj
2288 + moe_block
2289 };
2290 Ok(vec![
2291 per_layer_elems * dtype.size_in_bytes();
2292 cfg.num_hidden_layers
2293 ])
2294 }
2295
2296 fn num_layers(&self, config: &str) -> Result<usize> {
2297 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2298
2299 Ok(cfg.num_hidden_layers)
2300 }
2301
2302 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2303 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2304
2305 let cfg = ModelConfigMetadata {
2306 max_seq_len: cfg.max_position_embeddings,
2307 num_layers: cfg.num_hidden_layers,
2308 hidden_size: cfg.hidden_size,
2309 num_kv_heads: cfg.num_key_value_heads,
2310 num_attn_heads: cfg.num_attention_heads,
2311 sliding_window: cfg.sliding_window,
2312 k_head_dim: cfg.head_dim(),
2313 v_head_dim: cfg.head_dim(),
2314 };
2315
2316 Ok(Box::new(cfg))
2317 }
2318}
2319
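/// [`NormalModelLoader`] for a DeepSeek V2 model (MLA attention with optional MoE
/// layers). X-LoRA loading is not implemented.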
2320pub struct DeepSeekV2Loader;
2324
2325impl NormalModelLoader for DeepSeekV2Loader {
2326 fn load(
2327 &self,
2328 config: &str,
2329 vb: ShardedVarBuilder,
2330 normal_loading_metadata: NormalLoadingMetadata,
2331 attention_mechanism: AttentionImplementation,
2332 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2333 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2334
2335 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2336 &cfg,
2337 vb,
2338 self.is_gptx(config)?,
2339 normal_loading_metadata,
2340 attention_mechanism,
2341 )?))
2342 }
2343 fn load_xlora(
2344 &self,
2345 _config: &str,
2346 _vb: ShardedVarBuilder,
2347 _lora_config: &[((String, String), LoraConfig)],
2348 _xlora_config: Option<XLoraConfig>,
2349 _xlora_ordering: Ordering,
2350 _normal_loading_metadata: NormalLoadingMetadata,
2351 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2352 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2353 todo!()
2354 }
2355 fn is_gptx(&self, _: &str) -> Result<bool> {
2356 Ok(true)
2357 }
2358 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2359 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2360 Ok(Box::new(cfg))
2361 }
2362}
2363
2364impl IsqModelLoader for DeepSeekV2Loader {
2365 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2366 let mut data = vec![
2367 Regex::new(r"lm_head\.(weight|bias)$")?,
2368 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2370 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2371 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2372 ];
2373 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2374 if cfg.q_lora_rank.is_some() {
2375 data.extend(vec![
2376 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2377 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2378 ]);
2379 } else {
2380 data.push(Regex::new(
2381 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2382 )?);
2383 }
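        // Per-layer MLP regexes: MoE layers (from `first_k_dense_replace`, every
        // `moe_layer_freq` layers) cover routed and shared experts; the remaining
        // layers use the dense gate/up/down projections.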
2384 for layer_idx in 0..cfg.num_hidden_layers {
2385 if cfg.n_routed_experts.is_some()
2386 && layer_idx >= cfg.first_k_dense_replace
2387 && layer_idx % cfg.moe_layer_freq == 0
2388 {
2389 for i in 0..cfg.n_routed_experts.unwrap() {
2390 data.extend(vec![
2391 Regex::new(&format!(
2392 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2393 ))?,
2394 Regex::new(&format!(
2395 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2396 ))?,
2397 Regex::new(&format!(
2398 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2399 ))?,
2400 ]);
2401 }
2402 if cfg.n_shared_experts.is_some() {
2403 data.extend(vec![
2404 Regex::new(&format!(
2405 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2406 ))?,
2407 Regex::new(&format!(
2408 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2409 ))?,
2410 Regex::new(&format!(
2411 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2412 ))?,
2413 ]);
2414 }
2415 } else {
2416 data.extend(vec![
2417 Regex::new(&format!(
2418 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2419 ))?,
                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2421 Regex::new(&format!(
2422 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2423 ))?,
2424 ]);
2425 };
2426 }
2427 Ok(data)
2428 }
2429 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2430 self.isq_layer_regexes(config)
2431 }
2432
2433 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2434 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2435 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2436 for layer_idx in 0..cfg.num_hidden_layers {
2437 if cfg.n_routed_experts.is_some()
2438 && layer_idx >= cfg.first_k_dense_replace
2439 && layer_idx % cfg.moe_layer_freq == 0
2440 {
2441 for i in 0..cfg.n_routed_experts.unwrap() {
2442 data.extend(vec![
2443 Regex::new(&format!(
2444 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2445 ))?,
2446 Regex::new(&format!(
2447 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2448 ))?,
2449 Regex::new(&format!(
2450 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2451 ))?,
2452 ]);
2453 }
2454 if cfg.n_shared_experts.is_some() {
2455 data.extend(vec![
2456 Regex::new(&format!(
2457 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2458 ))?,
2459 Regex::new(&format!(
2460 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2461 ))?,
2462 Regex::new(&format!(
2463 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2464 ))?,
2465 ]);
2466 }
2467 } else {
2468 data.extend(vec![
2469 Regex::new(&format!(
2470 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2471 ))?,
                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2473 Regex::new(&format!(
2474 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2475 ))?,
2476 ]);
2477 };
2478 }
2479 Ok(data)
2480 }
2481 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2482 self.isq_layer_regexes_moqe(config)
2483 }
2484}
2485
2486impl DeviceMappedModelLoader for DeepSeekV2Loader {
2487 fn mapped_max_act_size_elems(
2488 &self,
2489 config: &str,
2490 params: &AutoDeviceMapParams,
2491 prompt_chunksize: usize,
2492 ) -> Result<usize> {
2493 let AutoDeviceMapParams::Text {
2494 max_seq_len: _,
2495 max_batch_size,
2496 } = params
2497 else {
2498 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2499 };
2500
2501 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2502
2503 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2504 }
2505 fn non_mapped_max_act_size_elems(
2506 &self,
2507 _config: &str,
2508 _params: &AutoDeviceMapParams,
2509 ) -> Result<usize> {
2510 Ok(0)
2511 }
2512
2513 fn non_mapped_size_in_bytes(
2514 &self,
2515 config: &str,
2516 dtype: DType,
2517 weight_pack_factor: usize,
2518 ) -> Result<usize> {
2519 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2520 let elems = {
2521 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2522 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2524 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2525 } else {
2526 0
2527 };
2528 let norm = cfg.hidden_size;
2529 embed_tokens + lm_head + norm
2530 };
2531 Ok(elems * dtype.size_in_bytes())
2532 }
2533
2534 fn layer_sizes_in_bytes(
2535 &self,
2536 config: &str,
2537 dtype: DType,
2538 weight_pack_factor: usize,
2539 ) -> Result<Vec<usize>> {
2540 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2541 let mut per_layer_elems = Vec::new();
2542
2543 for layer_idx in 0..cfg.num_hidden_layers {
2544 let input_layernorm = cfg.hidden_size;
2545 let post_attention_layernorm = cfg.hidden_size;
2546
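            // MLA query projection: either low-rank factored (q_a_proj + q_a_layernorm
            // + q_b_proj) when `q_lora_rank` is set, or a single dense q_proj.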
2547 let q_proj = match cfg.q_lora_rank {
2548 Some(lora_rank) => {
2549 let a = cfg.hidden_size * lora_rank;
2550 let norm = lora_rank;
2551 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2552 a + norm + b
2553 }
2554 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2555 };
2556 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2557 / weight_pack_factor
2558 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2559 let kv_a_layernorm = cfg.kv_lora_rank;
2560 let kv_b_proj = cfg.kv_lora_rank
2561 * cfg.num_attention_heads
2562 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2563 / weight_pack_factor;
2564 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2565 / weight_pack_factor
2566 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2567
2568 let moe_block = {
2569 let mut sum = 0;
2570 if cfg.n_routed_experts.is_some()
2571 && layer_idx >= cfg.first_k_dense_replace
2572 && layer_idx % cfg.moe_layer_freq == 0
2573 {
2574 let h_size = cfg.hidden_size;
2575 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2576 * cfg.n_routed_experts.unwrap();
2577 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2578 * cfg.n_routed_experts.unwrap();
2579 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2580 * cfg.n_routed_experts.unwrap();
2581 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2582 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2583 / weight_pack_factor;
2584 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2585 / weight_pack_factor;
2586 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2587 / weight_pack_factor;
2588 gate_proj + up_proj + down_proj
2589 } else {
2590 0
2591 };
2592 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2593 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2594 } else {
2595 let h_size = cfg.hidden_size;
2596 let i_size = cfg.intermediate_size;
2597 let gate_proj = h_size * i_size / weight_pack_factor;
2598 let up_proj = h_size * i_size / weight_pack_factor;
2599 let down_proj = i_size * h_size / weight_pack_factor;
2600 sum += gate_proj + up_proj + down_proj;
2601 }
2602 sum
2603 };
2604
2605 per_layer_elems.push(
2606 input_layernorm
2607 + post_attention_layernorm
2608 + q_proj
2609 + kv_a_layernorm
2610 + kv_a_proj_with_mqa
2611 + kv_b_proj
2612 + o_proj
2613 + moe_block,
2614 );
2615 }
2616
2617 Ok(per_layer_elems
2618 .into_iter()
2619 .map(|x| x * dtype.size_in_bytes())
2620 .collect())
2621 }
2622
2623 fn num_layers(&self, config: &str) -> Result<usize> {
2624 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2625 Ok(cfg.num_hidden_layers)
2626 }
2627
2628 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2629 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2630
2631 let cfg = ModelConfigMetadata {
2632 max_seq_len: cfg.max_position_embeddings,
2633 num_layers: cfg.num_hidden_layers,
2634 hidden_size: cfg.hidden_size,
2635 num_kv_heads: cfg.num_attention_heads,
2636 num_attn_heads: cfg.num_attention_heads,
2637 sliding_window: None,
2638 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2639 v_head_dim: cfg.v_head_dim,
2640 };
2641
2642 Ok(Box::new(cfg))
2643 }
2644}
2645
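/// [`NormalModelLoader`] for a DeepSeek V3 model; shares the MLA/MoE size accounting
/// scheme used for DeepSeek V2. X-LoRA loading is not implemented.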
2646pub struct DeepSeekV3Loader;
2650
2651impl NormalModelLoader for DeepSeekV3Loader {
2652 fn load(
2653 &self,
2654 config: &str,
2655 vb: ShardedVarBuilder,
2656 normal_loading_metadata: NormalLoadingMetadata,
2657 attention_mechanism: AttentionImplementation,
2658 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2659 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2660 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2661 &cfg,
2662 vb,
2663 self.is_gptx(config)?,
2664 normal_loading_metadata,
2665 attention_mechanism,
2666 )?))
2667 }
2668 fn load_xlora(
2669 &self,
2670 _config: &str,
2671 _vb: ShardedVarBuilder,
2672 _lora_config: &[((String, String), LoraConfig)],
2673 _xlora_config: Option<XLoraConfig>,
2674 _xlora_ordering: Ordering,
2675 _normal_loading_metadata: NormalLoadingMetadata,
2676 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2677 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2678 todo!()
2679 }
2680 fn is_gptx(&self, _: &str) -> Result<bool> {
2681 Ok(true)
2682 }
2683 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2684 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2685 Ok(Box::new(cfg))
2686 }
2687}
2688
2689impl IsqModelLoader for DeepSeekV3Loader {
2690 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2691 let mut data = vec![
2692 Regex::new(r"lm_head\.(weight|bias)$")?,
2693 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2695 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2696 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2697 ];
2698 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2699 if cfg.q_lora_rank.is_some() {
2700 data.extend(vec![
2701 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2702 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2703 ]);
2704 } else {
2705 data.push(Regex::new(
2706 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2707 )?);
2708 }
2709 for layer_idx in 0..cfg.num_hidden_layers {
2710 if cfg.n_routed_experts.is_some()
2711 && layer_idx >= cfg.first_k_dense_replace
2712 && layer_idx % cfg.moe_layer_freq == 0
2713 {
2714 for i in 0..cfg.n_routed_experts.unwrap() {
2715 data.extend(vec![
2716 Regex::new(&format!(
2717 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2718 ))?,
2719 Regex::new(&format!(
2720 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2721 ))?,
2722 Regex::new(&format!(
2723 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2724 ))?,
2725 ]);
2726 }
2727 if cfg.n_shared_experts.is_some() {
2728 data.extend(vec![
2729 Regex::new(&format!(
2730 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2731 ))?,
2732 Regex::new(&format!(
2733 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2734 ))?,
2735 Regex::new(&format!(
2736 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2737 ))?,
2738 ]);
2739 }
2740 } else {
2741 data.extend(vec![
2742 Regex::new(&format!(
2743 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2744 ))?,
                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2746 Regex::new(&format!(
2747 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2748 ))?,
2749 ]);
2750 };
2751 }
2752 Ok(data)
2753 }
2754 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2755 self.isq_layer_regexes(config)
2756 }
2757
2758 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2759 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2760 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2761 for layer_idx in 0..cfg.num_hidden_layers {
2762 if cfg.n_routed_experts.is_some()
2763 && layer_idx >= cfg.first_k_dense_replace
2764 && layer_idx % cfg.moe_layer_freq == 0
2765 {
2766 for i in 0..cfg.n_routed_experts.unwrap() {
2767 data.extend(vec![
2768 Regex::new(&format!(
2769 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2770 ))?,
2771 Regex::new(&format!(
2772 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2773 ))?,
2774 Regex::new(&format!(
2775 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2776 ))?,
2777 ]);
2778 }
2779 if cfg.n_shared_experts.is_some() {
2780 data.extend(vec![
2781 Regex::new(&format!(
2782 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2783 ))?,
2784 Regex::new(&format!(
2785 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2786 ))?,
2787 Regex::new(&format!(
2788 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2789 ))?,
2790 ]);
2791 }
2792 } else {
2793 data.extend(vec![
2794 Regex::new(&format!(
2795 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2796 ))?,
                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2798 Regex::new(&format!(
2799 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2800 ))?,
2801 ]);
2802 };
2803 }
2804 Ok(data)
2805 }
2806 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2807 self.isq_layer_regexes_moqe(config)
2808 }
2809}
2810
2811impl DeviceMappedModelLoader for DeepSeekV3Loader {
2812 fn mapped_max_act_size_elems(
2813 &self,
2814 config: &str,
2815 params: &AutoDeviceMapParams,
2816 prompt_chunksize: usize,
2817 ) -> Result<usize> {
2818 let AutoDeviceMapParams::Text {
2819 max_seq_len: _,
2820 max_batch_size,
2821 } = params
2822 else {
2823 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2824 };
2825
2826 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2827
2828 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2829 }
2830 fn non_mapped_max_act_size_elems(
2831 &self,
2832 _config: &str,
2833 _params: &AutoDeviceMapParams,
2834 ) -> Result<usize> {
2835 Ok(0)
2836 }
2837
2838 fn non_mapped_size_in_bytes(
2839 &self,
2840 config: &str,
2841 dtype: DType,
2842 weight_pack_factor: usize,
2843 ) -> Result<usize> {
2844 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2845 let elems = {
2846 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2847 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2849 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2850 } else {
2851 0
2852 };
2853 let norm = cfg.hidden_size;
2854 embed_tokens + lm_head + norm
2855 };
2856 Ok(elems * dtype.size_in_bytes())
2857 }
2858
2859 fn layer_sizes_in_bytes(
2860 &self,
2861 config: &str,
2862 dtype: DType,
2863 weight_pack_factor: usize,
2864 ) -> Result<Vec<usize>> {
2865 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2866 let mut per_layer_elems = Vec::new();
2867
2868 for layer_idx in 0..cfg.num_hidden_layers {
2869 let input_layernorm = cfg.hidden_size;
2870 let post_attention_layernorm = cfg.hidden_size;
2871
2872 let q_proj = match cfg.q_lora_rank {
2873 Some(lora_rank) => {
2874 let a = cfg.hidden_size * lora_rank;
2875 let norm = lora_rank;
2876 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2877 a + norm + b
2878 }
2879 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2880 };
2881 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2882 / weight_pack_factor
2883 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2884 let kv_a_layernorm = cfg.kv_lora_rank;
2885 let kv_b_proj = cfg.kv_lora_rank
2886 * cfg.num_attention_heads
2887 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2888 / weight_pack_factor;
2889 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2890 / weight_pack_factor
2891 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2892
2893 let moe_block = {
2894 let mut sum = 0;
2895 if cfg.n_routed_experts.is_some()
2896 && layer_idx >= cfg.first_k_dense_replace
2897 && layer_idx % cfg.moe_layer_freq == 0
2898 {
2899 let h_size = cfg.hidden_size;
2900 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2901 * cfg.n_routed_experts.unwrap();
2902 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2903 * cfg.n_routed_experts.unwrap();
2904 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2905 * cfg.n_routed_experts.unwrap();
2906 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2907 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2908 / weight_pack_factor;
2909 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2910 / weight_pack_factor;
2911 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2912 / weight_pack_factor;
2913 gate_proj + up_proj + down_proj
2914 } else {
2915 0
2916 };
2917 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2918 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2919 } else {
2920 let h_size = cfg.hidden_size;
2921 let i_size = cfg.intermediate_size;
2922 let gate_proj = h_size * i_size / weight_pack_factor;
2923 let up_proj = h_size * i_size / weight_pack_factor;
2924 let down_proj = i_size * h_size / weight_pack_factor;
2925 sum += gate_proj + up_proj + down_proj;
2926 }
2927 sum
2928 };
2929
2930 per_layer_elems.push(
2931 input_layernorm
2932 + post_attention_layernorm
2933 + q_proj
2934 + kv_a_layernorm
2935 + kv_a_proj_with_mqa
2936 + kv_b_proj
2937 + o_proj
2938 + moe_block,
2939 );
2940 }
2941
2942 Ok(per_layer_elems
2943 .into_iter()
2944 .map(|x| x * dtype.size_in_bytes())
2945 .collect())
2946 }
2947
2948 fn num_layers(&self, config: &str) -> Result<usize> {
2949 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2950 Ok(cfg.num_hidden_layers)
2951 }
2952
2953 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2954 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2955
2956 let cfg = ModelConfigMetadata {
2957 max_seq_len: cfg.max_position_embeddings,
2958 num_layers: cfg.num_hidden_layers,
2959 hidden_size: cfg.hidden_size,
2960 num_kv_heads: cfg.num_attention_heads,
2961 num_attn_heads: cfg.num_attention_heads,
2962 sliding_window: None,
2963 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2964 v_head_dim: cfg.v_head_dim,
2965 };
2966
2967 Ok(Box::new(cfg))
2968 }
2969}
2970
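/// [`NormalModelLoader`] for a Qwen 3 (dense) model. X-LoRA loading is not implemented.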
2971pub struct Qwen3Loader;
2975
2976impl NormalModelLoader for Qwen3Loader {
2977 fn load(
2978 &self,
2979 config: &str,
2980 vb: ShardedVarBuilder,
2981 normal_loading_metadata: NormalLoadingMetadata,
2982 attention_mechanism: AttentionImplementation,
2983 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2984 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
2985
2986 Ok(Box::new(models::qwen3::Model::new(
2987 &cfg,
2988 vb,
2989 self.is_gptx(config)?,
2990 normal_loading_metadata,
2991 attention_mechanism,
2992 )?))
2993 }
2994 fn load_xlora(
2995 &self,
2996 _config: &str,
2997 _vb: ShardedVarBuilder,
2998 _lora_config: &[((String, String), LoraConfig)],
2999 _xlora_config: Option<XLoraConfig>,
3000 _xlora_ordering: Ordering,
3001 _normal_loading_metadata: NormalLoadingMetadata,
3002 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3003 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3004 todo!()
3005 }
3006 fn is_gptx(&self, _: &str) -> Result<bool> {
3007 Ok(true)
3008 }
3009 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3010 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3011
3012 Ok(Box::new(cfg))
3013 }
3014}
3015
3016impl IsqModelLoader for Qwen3Loader {
3017 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3018 Ok(vec![
3019 Regex::new(r"lm_head\.(weight|bias)$")?,
3020 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3022 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3023 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3024 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3025 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3027 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3028 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3029 ])
3030 }
3031 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3032 self.isq_layer_regexes(config)
3033 }
3034}
3035
3036impl DeviceMappedModelLoader for Qwen3Loader {
3037 fn mapped_max_act_size_elems(
3038 &self,
3039 config: &str,
3040 params: &AutoDeviceMapParams,
3041 prompt_chunksize: usize,
3042 ) -> Result<usize> {
3043 let AutoDeviceMapParams::Text {
3044 max_seq_len: _,
3045 max_batch_size,
3046 } = params
3047 else {
3048 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3049 };
3050
3051 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3052
3053 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3054 }
3055 fn non_mapped_max_act_size_elems(
3056 &self,
3057 _config: &str,
3058 _params: &AutoDeviceMapParams,
3059 ) -> Result<usize> {
3060 Ok(0)
3061 }
3062
3063 fn non_mapped_size_in_bytes(
3064 &self,
3065 config: &str,
3066 dtype: DType,
3067 weight_pack_factor: usize,
3068 ) -> Result<usize> {
3069 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3070 let elems = {
3071 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3072 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3074 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3075 } else {
3076 0
3077 };
3078 let norm = cfg.hidden_size;
3079 embed_tokens + lm_head + norm
3080 };
3081 Ok(elems * dtype.size_in_bytes())
3082 }
3083
3084 fn layer_sizes_in_bytes(
3085 &self,
3086 config: &str,
3087 dtype: DType,
3088 weight_pack_factor: usize,
3089 ) -> Result<Vec<usize>> {
3090 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3091 let per_layer_elems = {
3092 let input_layernorm = cfg.hidden_size;
3093 let post_attention_layernorm = cfg.hidden_size;
3094
3095 let size_in = cfg.hidden_size;
3096 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3097 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3098 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3099 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3100 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3101 let o_proj = size_q * size_in / weight_pack_factor;
3102
3103 let h_size = cfg.hidden_size;
3104 let i_size = cfg.intermediate_size;
3105 let gate_proj = h_size * i_size / weight_pack_factor;
3106 let up_proj = h_size * i_size / weight_pack_factor;
3107 let down_proj = i_size * h_size / weight_pack_factor;
3108
3109 let q_norm = cfg.head_dim();
3110 let k_norm = cfg.head_dim();
3111
3112 input_layernorm
3113 + post_attention_layernorm
3114 + q_proj
3115 + k_proj
3116 + v_proj
3117 + o_proj
3118 + gate_proj
3119 + up_proj
3120 + down_proj
3121 + q_norm
3122 + k_norm
3123 };
3124 Ok(vec![
3125 per_layer_elems * dtype.size_in_bytes();
3126 cfg.num_hidden_layers
3127 ])
3128 }
3129
3130 fn num_layers(&self, config: &str) -> Result<usize> {
3131 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3132 Ok(cfg.num_hidden_layers)
3133 }
3134
3135 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3136 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3137
3138 let cfg = ModelConfigMetadata {
3139 max_seq_len: cfg.max_position_embeddings,
3140 num_layers: cfg.num_hidden_layers,
3141 hidden_size: cfg.hidden_size,
3142 num_kv_heads: cfg.num_key_value_heads,
3143 num_attn_heads: cfg.num_attention_heads,
3144 sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
3147 };
3148
3149 Ok(Box::new(cfg))
3150 }
3151}
3152
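/// [`NormalModelLoader`] for a GLM4 model. X-LoRA loading is not implemented.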
3153pub struct GLM4Loader;
3157
3158impl NormalModelLoader for GLM4Loader {
3159 fn load(
3160 &self,
3161 config: &str,
3162 vb: ShardedVarBuilder,
3163 normal_loading_metadata: NormalLoadingMetadata,
3164 attention_mechanism: AttentionImplementation,
3165 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3166 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3167
3168 Ok(Box::new(models::glm4::Model::new(
3169 &cfg,
3170 vb,
3171 self.is_gptx(config)?,
3172 normal_loading_metadata,
3173 attention_mechanism,
3174 )?))
3175 }
3176 fn load_xlora(
3177 &self,
3178 _config: &str,
3179 _vb: ShardedVarBuilder,
3180 _lora_config: &[((String, String), LoraConfig)],
3181 _xlora_config: Option<XLoraConfig>,
3182 _xlora_ordering: Ordering,
3183 _normal_loading_metadata: NormalLoadingMetadata,
3184 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3185 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3186 todo!()
3187 }
3188 fn is_gptx(&self, _: &str) -> Result<bool> {
3189 Ok(true)
3190 }
3191 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3192 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3193
3194 Ok(Box::new(cfg))
3195 }
3196}
3197
3198impl IsqModelLoader for GLM4Loader {
3199 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3200 Ok(vec![
3201 Regex::new(r"lm_head\.(weight|bias)$")?,
3202 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3204 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3205 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3206 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3207 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3209 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3210 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3211 ])
3212 }
3213 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3214 self.isq_layer_regexes(config)
3215 }
3216}
3217
3218impl DeviceMappedModelLoader for GLM4Loader {
3219 fn mapped_max_act_size_elems(
3220 &self,
3221 config: &str,
3222 params: &AutoDeviceMapParams,
3223 prompt_chunksize: usize,
3224 ) -> Result<usize> {
3225 let AutoDeviceMapParams::Text {
3226 max_seq_len: _,
3227 max_batch_size,
3228 } = params
3229 else {
3230 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3231 };
3232
3233 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3234
3235 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3236 }
3237 fn non_mapped_max_act_size_elems(
3238 &self,
3239 _config: &str,
3240 _params: &AutoDeviceMapParams,
3241 ) -> Result<usize> {
3242 Ok(0)
3243 }
3244
3245 fn non_mapped_size_in_bytes(
3246 &self,
3247 config: &str,
3248 dtype: DType,
3249 weight_pack_factor: usize,
3250 ) -> Result<usize> {
3251 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3252 let elems = {
3253 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3254 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3256 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3257 } else {
3258 0
3259 };
3260 let norm = cfg.hidden_size;
3261 embed_tokens + lm_head + norm
3262 };
3263 Ok(elems * dtype.size_in_bytes())
3264 }
3265
3266 fn layer_sizes_in_bytes(
3267 &self,
3268 config: &str,
3269 dtype: DType,
3270 weight_pack_factor: usize,
3271 ) -> Result<Vec<usize>> {
3272 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3273 let per_layer_elems = {
3274 let input_layernorm = cfg.hidden_size;
            // GLM4 blocks also carry post_self_attn_layernorm and post_mlp_layernorm,
            // hence the factor of 3 on top of the usual post-attention norm.
            let post_attention_layernorm = cfg.hidden_size * 3;

            let size_in = cfg.hidden_size;
3278 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3279 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3280 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3281 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3282 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3283 let o_proj = size_q * size_in / weight_pack_factor;
3284
3285 let h_size = cfg.hidden_size;
3286 let i_size = cfg.intermediate_size;
3287 let gate_proj = h_size * i_size / weight_pack_factor;
3288 let up_proj = h_size * i_size / weight_pack_factor;
3289 let down_proj = i_size * h_size / weight_pack_factor;
3290
3291 input_layernorm
3292 + post_attention_layernorm
3293 + q_proj
3294 + k_proj
3295 + v_proj
3296 + o_proj
3297 + gate_proj
3298 + up_proj
3299 + down_proj
3300 };
3301 Ok(vec![
3302 per_layer_elems * dtype.size_in_bytes();
3303 cfg.num_hidden_layers
3304 ])
3305 }
3306
3307 fn num_layers(&self, config: &str) -> Result<usize> {
3308 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3309 Ok(cfg.num_hidden_layers)
3310 }
3311
3312 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3313 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3314
3315 let cfg = ModelConfigMetadata {
3316 max_seq_len: cfg.max_position_embeddings,
3317 num_layers: cfg.num_hidden_layers,
3318 hidden_size: cfg.hidden_size,
3319 num_kv_heads: cfg.num_key_value_heads,
3320 num_attn_heads: cfg.num_attention_heads,
3321 sliding_window: cfg.sliding_window,
3322 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3323 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3324 };
3325
3326 Ok(Box::new(cfg))
3327 }
3328}
3329
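/// [`NormalModelLoader`] for a Qwen 3 MoE model. X-LoRA loading is not implemented.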
3330pub struct Qwen3MoELoader;
3334
3335impl NormalModelLoader for Qwen3MoELoader {
3336 fn load(
3337 &self,
3338 config: &str,
3339 vb: ShardedVarBuilder,
3340 normal_loading_metadata: NormalLoadingMetadata,
3341 attention_mechanism: AttentionImplementation,
3342 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3343 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3344
3345 Ok(Box::new(models::qwen3_moe::Model::new(
3346 &cfg,
3347 vb,
3348 self.is_gptx(config)?,
3349 normal_loading_metadata,
3350 attention_mechanism,
3351 )?))
3352 }
3353 fn load_xlora(
3354 &self,
3355 _config: &str,
3356 _vb: ShardedVarBuilder,
3357 _lora_config: &[((String, String), LoraConfig)],
3358 _xlora_config: Option<XLoraConfig>,
3359 _xlora_ordering: Ordering,
3360 _normal_loading_metadata: NormalLoadingMetadata,
3361 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3362 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3363 todo!()
3364 }
3365 fn is_gptx(&self, _: &str) -> Result<bool> {
3366 Ok(true)
3367 }
3368 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3369 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3370
3371 Ok(Box::new(cfg))
3372 }
3373}
3374
3375impl IsqModelLoader for Qwen3MoELoader {
3376 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3377 Ok(vec![
3378 Regex::new(r"lm_head\.(weight|bias)$")?,
3379 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3381 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3382 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3383 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3384 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3386 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3387 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3388 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3390 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3391 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3392 ])
3393 }
3394 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3395 self.isq_layer_regexes(config)
3396 }
3397 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3398 self.isq_layer_regexes_moqe(config)
3399 }
3400}
3401
3402impl DeviceMappedModelLoader for Qwen3MoELoader {
3403 fn mapped_max_act_size_elems(
3404 &self,
3405 config: &str,
3406 params: &AutoDeviceMapParams,
3407 prompt_chunksize: usize,
3408 ) -> Result<usize> {
3409 let AutoDeviceMapParams::Text {
3410 max_seq_len: _,
3411 max_batch_size,
3412 } = params
3413 else {
3414 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3415 };
3416
3417 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3418
3419 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3420 }
3421 fn non_mapped_max_act_size_elems(
3422 &self,
3423 _config: &str,
3424 _params: &AutoDeviceMapParams,
3425 ) -> Result<usize> {
3426 Ok(0)
3427 }
3428
3429 fn non_mapped_size_in_bytes(
3430 &self,
3431 config: &str,
3432 dtype: DType,
3433 weight_pack_factor: usize,
3434 ) -> Result<usize> {
3435 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3436 let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3438 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3440 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3441 } else {
3442 0
3443 };
3444 let norm = cfg.hidden_size;
3445 embed_tokens + lm_head + norm
3446 };
3447 Ok(elems * dtype.size_in_bytes())
3448 }
3449
3450 fn layer_sizes_in_bytes(
3451 &self,
3452 config: &str,
3453 dtype: DType,
3454 weight_pack_factor: usize,
3455 ) -> Result<Vec<usize>> {
3456 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3457
3458 let mut layer_sizes_in_bytes = Vec::new();
3459 for layer_idx in 0..cfg.num_hidden_layers {
3460 let input_layernorm = cfg.hidden_size;
3461 let post_attention_layernorm = cfg.hidden_size;
3462
3463 let size_in = cfg.hidden_size;
3464 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3465 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3466 let q_proj = size_in * size_q / weight_pack_factor;
3467 let k_proj = size_in * size_kv / weight_pack_factor;
3468 let v_proj = size_in * size_kv / weight_pack_factor;
3469 let o_proj = size_q * size_in / weight_pack_factor;
3470
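            // A layer uses the sparse MoE block when it is not listed in
            // `mlp_only_layers`, experts are configured, and it falls on the
            // `decoder_sparse_step` stride; otherwise it falls back to the dense MLP.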
3471 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3472 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3473 {
3474 let gate_size = cfg.hidden_size * cfg.num_experts;
3475 let expert_size = {
3476 let h_size = cfg.hidden_size;
3477 let i_size = cfg.moe_intermediate_size;
3478 let gate_proj = h_size * i_size / weight_pack_factor;
3479 let up_proj = h_size * i_size / weight_pack_factor;
3480 let down_proj = i_size * h_size / weight_pack_factor;
3481 gate_proj + up_proj + down_proj
3482 };
3483 expert_size * cfg.num_experts + gate_size
3484 } else {
3485 let h_size = cfg.hidden_size;
3486 let i_size = cfg.intermediate_size;
3487 let gate_proj = h_size * i_size / weight_pack_factor;
3488 let up_proj = h_size * i_size / weight_pack_factor;
3489 let down_proj = i_size * h_size / weight_pack_factor;
3490 gate_proj + up_proj + down_proj
3491 };
3492
3493 let q_norm = cfg.head_dim();
3494 let k_norm = cfg.head_dim();
3495
3496 let size_elems = input_layernorm
3497 + post_attention_layernorm
3498 + q_proj
3499 + k_proj
3500 + v_proj
3501 + o_proj
3502 + mlp_size
3503 + q_norm
3504 + k_norm;
3505
3506 let size_in_bytes = size_elems * dtype.size_in_bytes();
3507 layer_sizes_in_bytes.push(size_in_bytes);
3508 }
3509
3510 Ok(layer_sizes_in_bytes)
3511 }
3512
3513 fn num_layers(&self, config: &str) -> Result<usize> {
3514 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3515 Ok(cfg.num_hidden_layers)
3516 }
3517
3518 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3519 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3520
3521 let cfg = ModelConfigMetadata {
3522 max_seq_len: cfg.max_position_embeddings,
3523 num_layers: cfg.num_hidden_layers,
3524 hidden_size: cfg.hidden_size,
3525 num_kv_heads: cfg.num_key_value_heads,
3526 num_attn_heads: cfg.num_attention_heads,
3527 sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
3530 };
3531
3532 Ok(Box::new(cfg))
3533 }
3534}
3535
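/// [`NormalModelLoader`] for a SmolLM3 model. X-LoRA loading is not implemented.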
3536pub struct SmolLm3Loader;
3542
3543impl NormalModelLoader for SmolLm3Loader {
3544 fn load(
3545 &self,
3546 config: &str,
3547 vb: ShardedVarBuilder,
3548 normal_loading_metadata: NormalLoadingMetadata,
3549 attention_mechanism: AttentionImplementation,
3550 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3551 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3552
3553 Ok(Box::new(models::smollm3::SmolLm3::new(
3554 &cfg,
3555 vb,
3556 self.is_gptx(config)?,
3557 normal_loading_metadata,
3558 attention_mechanism,
3559 )?))
3560 }
3561 fn load_xlora(
3562 &self,
3563 _config: &str,
3564 _vb: ShardedVarBuilder,
3565 _lora_config: &[((String, String), LoraConfig)],
3566 _xlora_config: Option<XLoraConfig>,
3567 _xlora_ordering: Ordering,
3568 _normal_loading_metadata: NormalLoadingMetadata,
3569 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3570 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3571 todo!()
3572 }
3573 fn is_gptx(&self, _: &str) -> Result<bool> {
3574 Ok(true)
3575 }
3576 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3577 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3578 Ok(Box::new(cfg))
3579 }
3580}
3581
3582impl IsqModelLoader for SmolLm3Loader {
3583 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3584 Ok(vec![
3585 Regex::new(r"lm_head\.(weight|bias)$")?,
3586 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3588 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3589 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3590 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3591 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3593 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3594 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3595 ])
3596 }
3597 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3598 self.isq_layer_regexes(config)
3599 }
3600}
3601
3602impl DeviceMappedModelLoader for SmolLm3Loader {
3603 fn mapped_max_act_size_elems(
3604 &self,
3605 config: &str,
3606 params: &AutoDeviceMapParams,
3607 prompt_chunksize: usize,
3608 ) -> Result<usize> {
3609 let AutoDeviceMapParams::Text {
3610 max_seq_len: _,
3611 max_batch_size,
3612 } = params
3613 else {
3614 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3615 };
3616
3617 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3618
3619 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3620 }
3621 fn non_mapped_max_act_size_elems(
3622 &self,
3623 _config: &str,
3624 _params: &AutoDeviceMapParams,
3625 ) -> Result<usize> {
3626 Ok(0)
3627 }
3628
3629 fn non_mapped_size_in_bytes(
3630 &self,
3631 config: &str,
3632 dtype: DType,
3633 weight_pack_factor: usize,
3634 ) -> Result<usize> {
3635 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3636
3637 let elems = {
3638 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3639 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3641 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3642 } else {
3643 0
3644 };
3645 let norm = cfg.hidden_size;
3646 embed_tokens + lm_head + norm
3647 };
3648 Ok(elems * dtype.size_in_bytes())
3649 }
3650
3651 fn layer_sizes_in_bytes(
3652 &self,
3653 config: &str,
3654 dtype: DType,
3655 weight_pack_factor: usize,
3656 ) -> Result<Vec<usize>> {
3657 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3658
3659 let per_layer_elems = {
3660 let input_layernorm = cfg.hidden_size;
3661 let post_attention_layernorm = cfg.hidden_size;
3662
3663 let size_in = cfg.hidden_size;
3664 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3665 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3666 let q_proj = size_in * size_q / weight_pack_factor;
3667 let k_proj = size_in * size_kv / weight_pack_factor;
3668 let v_proj = size_in * size_kv / weight_pack_factor;
3669 let o_proj = size_q * size_in / weight_pack_factor;
3670
3671 let h_size = cfg.hidden_size;
3672 let i_size = cfg.intermediate_size;
3673 let gate_proj = h_size * i_size / weight_pack_factor;
3674 let up_proj = h_size * i_size / weight_pack_factor;
3675 let down_proj = i_size * h_size / weight_pack_factor;
3676
3677 input_layernorm
3678 + post_attention_layernorm
3679 + q_proj
3680 + k_proj
3681 + v_proj
3682 + o_proj
3683 + gate_proj
3684 + up_proj
3685 + down_proj
3686 };
3687 Ok(vec![
3688 per_layer_elems * dtype.size_in_bytes();
3689 cfg.num_hidden_layers
3690 ])
3691 }
3692
3693 fn num_layers(&self, config: &str) -> Result<usize> {
3694 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3695
3696 Ok(cfg.num_hidden_layers)
3697 }
3698 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3699 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3700
3701 let cfg = ModelConfigMetadata {
3702 max_seq_len: cfg.max_position_embeddings,
3703 num_layers: cfg.num_hidden_layers,
3704 hidden_size: cfg.hidden_size,
3705 num_kv_heads: cfg.num_key_value_heads,
3706 num_attn_heads: cfg.num_attention_heads,
3707 sliding_window: None,
3708 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3709 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3710 };
3711
3712 Ok(Box::new(cfg))
3713 }
3714}