1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 str::FromStr,
5 sync::Arc,
6};
7
8use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};
9
10use crate::{
11 amoe::AnyMoeBaseModelMixin,
12 device_map::DeviceMapper,
13 lora::{LoraConfig, Ordering},
14 paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
15 pipeline::{
16 isq::IsqModelLoader,
17 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
18 EitherCache, IsqModel,
19 },
20 utils::varbuilder_utils::DeviceForLoadTensor,
21 xlora_models::NonGranularState,
22};
23use anyhow::Result;
24use candle_core::{DType, Device, Tensor};
25use mistralrs_quant::log::once_log_info;
26
27use indicatif::MultiProgress;
28use mistralrs_quant::ShardedVarBuilder;
29#[cfg(feature = "pyo3_macros")]
30use pyo3::pyclass;
31
32use regex::Regex;
33use serde::Deserialize;
34
35use crate::{
36 models,
37 xlora_models::{self, XLoraConfig},
38};
39
40use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
41
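/// A standard (non-vision) causal language model: it exposes a forward pass, an optional
/// X-LoRA forward pass, and access to its KV cache, device, and configuration metadata.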
42pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
43 #[allow(clippy::too_many_arguments)]
44 fn forward(
45 &self,
46 input_ids: &Tensor,
47 seqlen_offsets: &[usize],
48 context_lens: Vec<(usize, usize)>,
49 position_ids: Vec<usize>,
50 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
51 flash_params: &FlashParams,
52 ) -> candle_core::Result<Tensor>;
53 #[allow(clippy::too_many_arguments)]
54 fn xlora_forward(
55 &self,
56 input_ids: &Tensor,
57 input_ids_full: &Tensor,
58 seqlen_offsets: &[usize],
59 seqlen_offsets_full: &[usize],
60 no_kv_cache: bool,
61 non_granular_state: &Option<NonGranularState>,
62 context_lens: Vec<(usize, usize)>,
63 position_ids: Vec<usize>,
64 flash_params: &FlashParams,
65 flash_params_full: &FlashParams,
66 ) -> candle_core::Result<Tensor>;
67 fn is_xlora(&self) -> bool;
68 fn device(&self) -> &Device;
69 fn cache(&self) -> &EitherCache;
70 fn cache_mut(&mut self) -> &mut EitherCache;
71 fn max_seq_len(&self) -> usize;
72 fn config(&self) -> &ModelConfigMetadata;
73}
74
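/// Metadata handed to every [`NormalModelLoader`] when constructing a model: device
/// mapping, ISQ state, target device, progress reporting, and optional Matformer slicing.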
75pub struct NormalLoadingMetadata {
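    /// Device mapper used to place each layer on its target device during loading.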
77 pub mapper: Box<dyn DeviceMapper + Send + Sync>,
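    /// Whether the model is being loaded for in-situ quantization (ISQ).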
79 pub loading_isq: bool,
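    /// The primary (non-CPU) device the model targets.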
81 pub real_device: Device,
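    /// Shared progress bars for parallelized loading.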
83 pub multi_progress: Arc<MultiProgress>,
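    /// Optional Matformer slicing configuration.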
85 pub matformer_slicing_config: Option<MatformerSliceConfig>,
87}
88
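/// Constructs a [`NormalModel`] (optionally with X-LoRA adapters) from a serialized
/// `config.json` string and a [`ShardedVarBuilder`], and answers config-level queries
/// such as paged-attention support and per-tensor device placement.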
89pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
90 fn load(
91 &self,
92 config: &str,
93 vb: ShardedVarBuilder,
94 normal_loading_metadata: NormalLoadingMetadata,
95 attention_mechanism: AttentionImplementation,
96 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
97 #[allow(clippy::too_many_arguments)]
98 fn load_xlora(
99 &self,
100 config: &str,
101 vb: ShardedVarBuilder,
102 lora_config: &[((String, String), LoraConfig)],
103 xlora_config: Option<XLoraConfig>,
104 xlora_ordering: Ordering,
105 normal_loading_metadata: NormalLoadingMetadata,
106 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
107 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
108 fn is_gptx(&self, config: &str) -> Result<bool>;
109 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
110 Ok(true)
111 }
112 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
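    /// Choose the load-time device for each tensor by name. By default, tensors under
    /// `...layers.<n>...` are assigned to layer `n`'s mapped device, and everything else
    /// (as well as all tensors when loading for ISQ) stays on the base device.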
113 fn get_device_for_tensor(
114 &self,
115 config: &str,
116 _mapper: &dyn DeviceMapper,
117 loading_isq: bool,
118 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
119 if loading_isq {
120 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
121 } else {
122 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
123 let num_layers = self.model_config(config)?.num_layers();
124 let closure = move |name: String| {
125 if let Some(captures) = re.captures(&name) {
126 captures
127 .get(1)
128 .and_then(|m| m.as_str().parse::<usize>().ok())
129 .map(|l| l.min(num_layers))
130 .map(DeviceForLoadTensor::Idx)
131 .unwrap_or(DeviceForLoadTensor::Base)
132 } else {
133 DeviceForLoadTensor::Base
134 }
135 };
136
137 Ok(Arc::new(closure))
138 }
139 }
140}
141
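/// The architecture to load a normal (text) model as; parsed from its lowercase
/// identifier (e.g. `"mistral"`, `"llama"`) via `FromStr` or serde.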
142#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
143#[derive(Clone, Debug, Deserialize, PartialEq)]
144pub enum NormalLoaderType {
146 #[serde(rename = "mistral")]
147 Mistral,
148 #[serde(rename = "gemma")]
149 Gemma,
150 #[serde(rename = "mixtral")]
151 Mixtral,
152 #[serde(rename = "llama")]
153 Llama,
154 #[serde(rename = "phi2")]
155 Phi2,
156 #[serde(rename = "phi3")]
157 Phi3,
158 #[serde(rename = "qwen2")]
159 Qwen2,
160 #[serde(rename = "gemma2")]
161 Gemma2,
162 #[serde(rename = "starcoder2")]
163 Starcoder2,
164 #[serde(rename = "phi3.5moe")]
165 Phi3_5MoE,
166 #[serde(rename = "deepseekv2")]
167 DeepSeekV2,
168 #[serde(rename = "deepseekv3")]
169 DeepSeekV3,
170 #[serde(rename = "qwen3")]
171 Qwen3,
172 #[serde(rename = "glm4")]
173 GLM4,
174 #[serde(rename = "qwen3moe")]
175 Qwen3Moe,
176 #[serde(rename = "smollm3")]
177 SmolLm3,
178}
179
180impl NormalLoaderType {
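    /// Map a Hugging Face `architectures` class name to a loader type.
    ///
    /// A minimal sketch of the expected mapping (see the match arms below):
    /// ```ignore
    /// let tp = NormalLoaderType::from_causal_lm_name("MistralForCausalLM").unwrap();
    /// assert_eq!(tp, NormalLoaderType::Mistral);
    /// ```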
182 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
183 match name {
184 "MistralForCausalLM" => Ok(Self::Mistral),
185 "MixtralForCausalLM" => Ok(Self::Mixtral),
186 "GemmaForCausalLM" => Ok(Self::Gemma),
187 "Gemma2ForCausalLM" => Ok(Self::Gemma2),
188 "PhiForCausalLM" => Ok(Self::Phi2),
189 "Phi3ForCausalLM" => Ok(Self::Phi3),
190 "LlamaForCausalLM" => Ok(Self::Llama),
191 "Qwen2ForCausalLM" => Ok(Self::Qwen2),
192 "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
193 "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
194 "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
195 "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
196 "Qwen3ForCausalLM" => Ok(Self::Qwen3),
197 "Glm4ForCausalLM" => Ok(Self::GLM4),
198 "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
199 "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
200 other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `-CausalLM` model class `{other}`. Please raise an issue."
202 ),
203 }
204 }
205}
206
207impl FromStr for NormalLoaderType {
208 type Err = String;
209 fn from_str(s: &str) -> Result<Self, Self::Err> {
210 match s {
211 "mistral" => Ok(Self::Mistral),
212 "gemma" => Ok(Self::Gemma),
213 "mixtral" => Ok(Self::Mixtral),
214 "llama" => Ok(Self::Llama),
215 "phi2" => Ok(Self::Phi2),
216 "phi3" => Ok(Self::Phi3),
217 "qwen2" => Ok(Self::Qwen2),
218 "gemma2" => Ok(Self::Gemma2),
219 "starcoder2" => Ok(Self::Starcoder2),
220 "phi3.5moe" => Ok(Self::Phi3_5MoE),
221 "deepseekv2" => Ok(Self::DeepSeekV2),
222 "deepseekv3" => Ok(Self::DeepSeekV3),
223 "qwen3" => Ok(Self::Qwen3),
224 "glm4" => Ok(Self::GLM4),
225 "qwen3moe" => Ok(Self::Qwen3Moe),
226 "smollm3" => Ok(Self::SmolLm3),
227 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`.")),
228 }
229 }
230}
231
232impl Display for NormalLoaderType {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 match self {
235 Self::Gemma => write!(f, "gemma"),
236 Self::Gemma2 => write!(f, "gemma2"),
237 Self::Llama => write!(f, "llama"),
238 Self::Mistral => write!(f, "mistral"),
239 Self::Mixtral => write!(f, "mixtral"),
240 Self::Phi2 => write!(f, "phi2"),
241 Self::Phi3 => write!(f, "phi3"),
242 Self::Phi3_5MoE => write!(f, "phi3.5moe"),
243 Self::Qwen2 => write!(f, "qwen2"),
244 Self::Starcoder2 => write!(f, "starcoder2"),
245 Self::DeepSeekV2 => write!(f, "deepseekv2"),
246 Self::DeepSeekV3 => write!(f, "deepseekv3"),
247 Self::Qwen3 => write!(f, "qwen3"),
248 Self::GLM4 => write!(f, "glm4"),
249 Self::Qwen3Moe => write!(f, "qwen3moe"),
250 Self::SmolLm3 => write!(f, "smollm3"),
251 }
252 }
253}
254
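/// Helper for the per-layer size estimates below: expands to `$size` when `$cond`
/// (typically a bias flag in the model config) is true, and to `0` otherwise.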
255macro_rules! bias_if {
256 ($cond:expr, $size:expr) => {
257 if $cond {
258 $size
259 } else {
260 0
261 }
262 };
263}
264
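/// Loader that picks the concrete [`NormalModelLoader`] automatically by reading the
/// `architectures` field of the model's `config.json` and delegating every call to it.
///
/// A minimal sketch of the field it dispatches on (all other config fields elided here):
/// ```ignore
/// let config = r#"{ "architectures": ["MistralForCausalLM"] }"#;
/// // AutoNormalLoader routes this config to `MistralLoader`.
/// ```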
265pub struct AutoNormalLoader;
267
268#[derive(Deserialize)]
269struct AutoNormalLoaderConfig {
270 architectures: Vec<String>,
271}
272
273impl AutoNormalLoader {
274 fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
275 let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
276 if auto_cfg.architectures.len() != 1 {
277 anyhow::bail!("Expected to have one name for `architectures` config field.")
278 }
279
280 let name = &auto_cfg.architectures[0];
281
282 let tp = NormalLoaderType::from_causal_lm_name(name)?;
283
284 once_log_info(format!("Automatic loader type determined to be `{tp}`"));
285
286 match tp {
287 NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
288 NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
289 NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
290 NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
291 NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
292 NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
293 NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
294 NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
295 NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
296 NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
297 NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
298 NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
299 NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
300 NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
301 NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
302 NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
303 }
304 }
305}
306
307impl NormalModelLoader for AutoNormalLoader {
308 fn load(
309 &self,
310 config: &str,
311 vb: ShardedVarBuilder,
312 normal_loading_metadata: NormalLoadingMetadata,
313 attention_mechanism: AttentionImplementation,
314 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
315 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
316 }
317 fn load_xlora(
318 &self,
319 config: &str,
320 vb: ShardedVarBuilder,
321 lora_config: &[((String, String), LoraConfig)],
322 xlora_config: Option<XLoraConfig>,
323 xlora_ordering: Ordering,
324 normal_loading_metadata: NormalLoadingMetadata,
325 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
326 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
327 Self::get_loader(config)?.load_xlora(
328 config,
329 vb,
330 lora_config,
331 xlora_config,
332 xlora_ordering,
333 normal_loading_metadata,
334 preload_adapters,
335 )
336 }
337 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
338 Self::get_loader(config)?.get_config_repr(config)
339 }
340 fn supports_paged_attention(&self, config: &str) -> Result<bool> {
341 Self::get_loader(config)?.supports_paged_attention(config)
342 }
343 fn is_gptx(&self, config: &str) -> Result<bool> {
344 Self::get_loader(config)?.is_gptx(config)
345 }
346}
347
348impl IsqModelLoader for AutoNormalLoader {
349 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
350 Self::get_loader(config)?.immediate_isq_predicates(config)
351 }
352 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
353 Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
354 }
355 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
356 Self::get_loader(config)?.isq_layer_regexes(config)
357 }
358 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
359 Self::get_loader(config)?.isq_layer_regexes_moqe(config)
360 }
361}
362
363impl DeviceMappedModelLoader for AutoNormalLoader {
364 fn non_mapped_size_in_bytes(
365 &self,
366 config: &str,
367 dtype: DType,
368 weight_pack_factor: usize,
369 _matformer_config: Option<&MatformerSliceConfig>,
370 ) -> Result<usize> {
371 Self::get_loader(config)?.non_mapped_size_in_bytes(
372 config,
373 dtype,
374 weight_pack_factor,
375 _matformer_config,
376 )
377 }
378 fn num_layers(&self, config: &str) -> Result<usize> {
379 Self::get_loader(config)?.num_layers(config)
380 }
381 fn layer_sizes_in_bytes(
382 &self,
383 config: &str,
384 dtype: DType,
385 weight_pack_factor: usize,
386 _matformer_config: Option<&MatformerSliceConfig>,
387 ) -> Result<Vec<usize>> {
388 Self::get_loader(config)?.layer_sizes_in_bytes(
389 config,
390 dtype,
391 weight_pack_factor,
392 _matformer_config,
393 )
394 }
395 fn mapped_max_act_size_elems(
396 &self,
397 config: &str,
398 params: &super::AutoDeviceMapParams,
399 ) -> Result<usize> {
400 Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
401 }
402 fn non_mapped_max_act_size_elems(
403 &self,
404 _config: &str,
405 _params: &AutoDeviceMapParams,
406 ) -> Result<usize> {
407 Ok(0)
408 }
409 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
410 Self::get_loader(config)?.model_config(config)
411 }
412}
413
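/// [`NormalModelLoader`] for Mistral models.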
414pub struct MistralLoader;
417
418impl NormalModelLoader for MistralLoader {
419 fn load(
420 &self,
421 config: &str,
422 vb: ShardedVarBuilder,
423 normal_loading_metadata: NormalLoadingMetadata,
424 attention_mechanism: AttentionImplementation,
425 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
426 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
427 Ok(Box::new(models::mistral::Model::new(
428 &cfg,
429 vb,
430 self.is_gptx(config)?,
431 normal_loading_metadata,
432 attention_mechanism,
433 )?))
434 }
435 fn load_xlora(
436 &self,
437 config: &str,
438 vb: ShardedVarBuilder,
439 lora_config: &[((String, String), LoraConfig)],
440 xlora_config: Option<XLoraConfig>,
441 xlora_ordering: Ordering,
442 normal_loading_metadata: NormalLoadingMetadata,
443 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
444 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
445 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
446 Ok(Box::new(xlora_models::XLoraMistral::new(
447 &cfg,
448 vb,
449 lora_config,
450 xlora_config,
451 xlora_ordering,
452 self.is_gptx(config)?,
453 normal_loading_metadata,
454 preload_adapters,
455 )?))
456 }
457 fn is_gptx(&self, _: &str) -> Result<bool> {
458 Ok(true)
459 }
460 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
461 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
462 Ok(Box::new(cfg))
463 }
464}
465
466impl IsqModelLoader for MistralLoader {
467 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
468 Ok(vec![
469 Regex::new(r"lm_head\.(weight|bias)$")?,
470 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
472 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
473 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
474 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
475 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
477 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
478 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
479 ])
480 }
481 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
482 self.isq_layer_regexes(config)
483 }
484}
485
486impl DeviceMappedModelLoader for MistralLoader {
487 fn mapped_max_act_size_elems(
488 &self,
489 config: &str,
490 params: &AutoDeviceMapParams,
491 ) -> Result<usize> {
492 let AutoDeviceMapParams::Text {
493 max_seq_len,
494 max_batch_size,
495 } = params
496 else {
497 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
498 };
499
500 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
501
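        // Peak mapped activation: one attention-score tensor of shape
        // [batch, heads, seq, seq], with the sequence length capped at ATTENTION_CHUNK_SIZE.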
502 Ok(
503 max_batch_size
504 * cfg.num_attention_heads
505 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
506 )
507 }
508 fn non_mapped_max_act_size_elems(
509 &self,
510 _config: &str,
511 _params: &AutoDeviceMapParams,
512 ) -> Result<usize> {
513 Ok(0)
514 }
515
516 fn non_mapped_size_in_bytes(
517 &self,
518 config: &str,
519 dtype: DType,
520 weight_pack_factor: usize,
521 _matformer_config: Option<&MatformerSliceConfig>,
522 ) -> Result<usize> {
523 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
524
525 let elems = {
526 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
527 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
529 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
530 } else {
531 0
532 };
533 let norm = cfg.hidden_size;
534 embed_tokens + lm_head + norm
535 };
536 Ok(elems * dtype.size_in_bytes())
537 }
538
539 fn layer_sizes_in_bytes(
540 &self,
541 config: &str,
542 dtype: DType,
543 weight_pack_factor: usize,
544 _matformer_config: Option<&MatformerSliceConfig>,
545 ) -> Result<Vec<usize>> {
546 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
547
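        // Per-layer element count: two norms, the four attention projections, and the three
        // MLP projections; packed (quantized) weights are scaled down by `weight_pack_factor`.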
548 let per_layer_elems = {
549 let input_layernorm = cfg.hidden_size;
550 let post_attention_layernorm = cfg.hidden_size;
551
552 let size_in = cfg.hidden_size;
553 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
554 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
555 let q_proj = size_in * size_q / weight_pack_factor;
556 let k_proj = size_in * size_kv / weight_pack_factor;
557 let v_proj = size_in * size_kv / weight_pack_factor;
558 let o_proj = size_q * size_in / weight_pack_factor;
559
560 let h_size = cfg.hidden_size;
561 let i_size = cfg.intermediate_size;
562 let gate_proj = h_size * i_size / weight_pack_factor;
563 let up_proj = h_size * i_size / weight_pack_factor;
564 let down_proj = i_size * h_size / weight_pack_factor;
565
566 input_layernorm
567 + post_attention_layernorm
568 + q_proj
569 + k_proj
570 + v_proj
571 + o_proj
572 + gate_proj
573 + up_proj
574 + down_proj
575 };
576 Ok(vec![
577 per_layer_elems * dtype.size_in_bytes();
578 cfg.num_hidden_layers
579 ])
580 }
581
582 fn num_layers(&self, config: &str) -> Result<usize> {
583 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
584 Ok(cfg.num_hidden_layers)
585 }
586
587 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
588 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
589
590 let cfg = ModelConfigMetadata {
591 max_seq_len: cfg.max_position_embeddings,
592 num_layers: cfg.num_hidden_layers,
593 hidden_size: cfg.hidden_size,
594 num_kv_heads: cfg.num_key_value_heads,
595 num_attn_heads: cfg.num_attention_heads,
596 sliding_window: cfg.sliding_window,
597 k_head_dim: cfg.head_dim(),
598 v_head_dim: cfg.head_dim(),
599 };
600
601 Ok(Box::new(cfg))
602 }
603}
604
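/// [`NormalModelLoader`] for Gemma models.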
605pub struct GemmaLoader;
611
612impl NormalModelLoader for GemmaLoader {
613 fn load(
614 &self,
615 config: &str,
616 vb: ShardedVarBuilder,
617 normal_loading_metadata: NormalLoadingMetadata,
618 attention_mechanism: AttentionImplementation,
619 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
620 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
621
622 Ok(Box::new(models::gemma::Model::new(
623 &cfg,
624 vb,
625 self.is_gptx(config)?,
626 normal_loading_metadata,
627 attention_mechanism,
628 )?))
629 }
630 fn load_xlora(
631 &self,
632 config: &str,
633 vb: ShardedVarBuilder,
634 lora_config: &[((String, String), LoraConfig)],
635 xlora_config: Option<XLoraConfig>,
636 xlora_ordering: Ordering,
637 normal_loading_metadata: NormalLoadingMetadata,
638 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
639 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
640 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
641
642 Ok(Box::new(xlora_models::XLoraGemma::new(
643 &cfg,
644 vb,
645 lora_config,
646 xlora_config,
647 xlora_ordering,
648 self.is_gptx(config)?,
649 normal_loading_metadata,
650 preload_adapters,
651 )?))
652 }
653 fn is_gptx(&self, _: &str) -> Result<bool> {
654 Ok(true)
655 }
656 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
657 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
658 Ok(Box::new(cfg))
659 }
660}
661
662impl IsqModelLoader for GemmaLoader {
663 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
664 Ok(vec![
665 Regex::new(r"lm_head\.(weight|bias)$")?,
666 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
668 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
669 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
670 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
671 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
673 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
674 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
675 ])
676 }
677 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
678 self.isq_layer_regexes(config)
679 }
680}
681
682impl DeviceMappedModelLoader for GemmaLoader {
683 fn mapped_max_act_size_elems(
684 &self,
685 config: &str,
686 params: &AutoDeviceMapParams,
687 ) -> Result<usize> {
688 let AutoDeviceMapParams::Text {
689 max_seq_len,
690 max_batch_size,
691 } = params
692 else {
693 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
694 };
695
696 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
697
698 Ok(
699 max_batch_size
700 * cfg.num_attention_heads
701 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
702 )
703 }
704 fn non_mapped_max_act_size_elems(
705 &self,
706 _config: &str,
707 _params: &AutoDeviceMapParams,
708 ) -> Result<usize> {
709 Ok(0)
710 }
711
712 fn non_mapped_size_in_bytes(
713 &self,
714 config: &str,
715 dtype: DType,
716 weight_pack_factor: usize,
717 _matformer_config: Option<&MatformerSliceConfig>,
718 ) -> Result<usize> {
719 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
720
721 let elems = {
722 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
723 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
725 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
726 } else {
727 0
728 };
729 let norm = cfg.hidden_size;
730 embed_tokens + lm_head + norm
731 };
732 Ok(elems * dtype.size_in_bytes())
733 }
734
735 fn layer_sizes_in_bytes(
736 &self,
737 config: &str,
738 dtype: DType,
739 weight_pack_factor: usize,
740 _matformer_config: Option<&MatformerSliceConfig>,
741 ) -> Result<Vec<usize>> {
742 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
743
744 let per_layer_elems = {
745 let input_layernorm = cfg.hidden_size;
746 let post_attention_layernorm = cfg.hidden_size;
747
748 let size_in = cfg.hidden_size;
749 let size_q = cfg.head_dim * cfg.num_attention_heads;
750 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
751 let q_proj =
752 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
753 let k_proj =
754 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
755 let v_proj =
756 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
757 let o_proj =
758 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
759
760 let h_size = cfg.hidden_size;
761 let i_size = cfg.intermediate_size;
762 let gate_proj = h_size * i_size / weight_pack_factor;
763 let up_proj = h_size * i_size / weight_pack_factor;
764 let down_proj = i_size * h_size / weight_pack_factor;
765
766 input_layernorm
767 + post_attention_layernorm
768 + q_proj
769 + k_proj
770 + v_proj
771 + o_proj
772 + gate_proj
773 + up_proj
774 + down_proj
775 };
776 Ok(vec![
777 per_layer_elems * dtype.size_in_bytes();
778 cfg.num_hidden_layers
779 ])
780 }
781
782 fn num_layers(&self, config: &str) -> Result<usize> {
783 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
784 Ok(cfg.num_hidden_layers)
785 }
786
787 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
788 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
789
790 let cfg = ModelConfigMetadata {
791 max_seq_len: cfg.max_position_embeddings,
792 num_layers: cfg.num_hidden_layers,
793 hidden_size: cfg.hidden_size,
794 num_kv_heads: cfg.num_key_value_heads,
795 num_attn_heads: cfg.num_attention_heads,
796 sliding_window: None,
797 k_head_dim: cfg.head_dim,
798 v_head_dim: cfg.head_dim,
799 };
800
801 Ok(Box::new(cfg))
802 }
803}
804
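/// [`NormalModelLoader`] for Llama models.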
805pub struct LlamaLoader;
811
812impl NormalModelLoader for LlamaLoader {
813 fn load(
814 &self,
815 config: &str,
816 vb: ShardedVarBuilder,
817 normal_loading_metadata: NormalLoadingMetadata,
818 attention_mechanism: AttentionImplementation,
819 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
820 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
821
822 Ok(Box::new(models::llama::Llama::new(
823 &cfg,
824 vb,
825 self.is_gptx(config)?,
826 normal_loading_metadata,
827 attention_mechanism,
828 )?))
829 }
830 fn load_xlora(
831 &self,
832 config: &str,
833 vb: ShardedVarBuilder,
834 lora_config: &[((String, String), LoraConfig)],
835 xlora_config: Option<XLoraConfig>,
836 xlora_ordering: Ordering,
837 normal_loading_metadata: NormalLoadingMetadata,
838 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
839 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
840 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
841
842 Ok(Box::new(xlora_models::XLoraLlama::new(
843 &cfg,
844 vb,
845 lora_config,
846 xlora_config,
847 xlora_ordering,
848 self.is_gptx(config)?,
849 normal_loading_metadata,
850 preload_adapters,
851 )?))
852 }
853 fn is_gptx(&self, _: &str) -> Result<bool> {
854 Ok(true)
855 }
856 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
857 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
858 Ok(Box::new(cfg))
859 }
860}
861
862impl IsqModelLoader for LlamaLoader {
863 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
864 Ok(vec![
865 Regex::new(r"lm_head\.(weight|bias)$")?,
866 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
868 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
869 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
870 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
871 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
873 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
874 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
875 ])
876 }
877 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
878 self.isq_layer_regexes(config)
879 }
880}
881
882impl DeviceMappedModelLoader for LlamaLoader {
883 fn mapped_max_act_size_elems(
884 &self,
885 config: &str,
886 params: &AutoDeviceMapParams,
887 ) -> Result<usize> {
888 let AutoDeviceMapParams::Text {
889 max_seq_len,
890 max_batch_size,
891 } = params
892 else {
893 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
894 };
895
896 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
897
898 Ok(
899 max_batch_size
900 * cfg.num_attention_heads
901 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
902 )
903 }
904 fn non_mapped_max_act_size_elems(
905 &self,
906 _config: &str,
907 _params: &AutoDeviceMapParams,
908 ) -> Result<usize> {
909 Ok(0)
910 }
911
912 fn non_mapped_size_in_bytes(
913 &self,
914 config: &str,
915 dtype: DType,
916 weight_pack_factor: usize,
917 _matformer_config: Option<&MatformerSliceConfig>,
918 ) -> Result<usize> {
919 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
920
921 let elems = {
922 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
923 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
925 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
926 } else {
927 0
928 };
929 let norm = cfg.hidden_size;
930 embed_tokens + lm_head + norm
931 };
932 Ok(elems * dtype.size_in_bytes())
933 }
934
935 fn layer_sizes_in_bytes(
936 &self,
937 config: &str,
938 dtype: DType,
939 weight_pack_factor: usize,
940 _matformer_config: Option<&MatformerSliceConfig>,
941 ) -> Result<Vec<usize>> {
942 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
943
944 let per_layer_elems = {
945 let input_layernorm = cfg.hidden_size;
946 let post_attention_layernorm = cfg.hidden_size;
947
948 let size_in = cfg.hidden_size;
949 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
950 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
951 let q_proj = size_in * size_q / weight_pack_factor;
952 let k_proj = size_in * size_kv / weight_pack_factor;
953 let v_proj = size_in * size_kv / weight_pack_factor;
954 let o_proj = size_q * size_in / weight_pack_factor;
955
956 let h_size = cfg.hidden_size;
957 let i_size = cfg.intermediate_size;
958 let gate_proj = h_size * i_size / weight_pack_factor;
959 let up_proj = h_size * i_size / weight_pack_factor;
960 let down_proj = i_size * h_size / weight_pack_factor;
961
962 input_layernorm
963 + post_attention_layernorm
964 + q_proj
965 + k_proj
966 + v_proj
967 + o_proj
968 + gate_proj
969 + up_proj
970 + down_proj
971 };
972 Ok(vec![
973 per_layer_elems * dtype.size_in_bytes();
974 cfg.num_hidden_layers
975 ])
976 }
977
978 fn num_layers(&self, config: &str) -> Result<usize> {
979 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
980
981 Ok(cfg.num_hidden_layers)
982 }
983 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
984 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
985
986 let cfg = ModelConfigMetadata {
987 max_seq_len: cfg.max_position_embeddings,
988 num_layers: cfg.num_hidden_layers,
989 hidden_size: cfg.hidden_size,
990 num_kv_heads: cfg.num_key_value_heads,
991 num_attn_heads: cfg.num_attention_heads,
992 sliding_window: None,
993 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
994 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
995 };
996
997 Ok(Box::new(cfg))
998 }
999}
1000
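/// [`NormalModelLoader`] for Mixtral (sparse MoE) models.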
1001pub struct MixtralLoader;
1004
1005impl NormalModelLoader for MixtralLoader {
1006 fn load(
1007 &self,
1008 config: &str,
1009 vb: ShardedVarBuilder,
1010 normal_loading_metadata: NormalLoadingMetadata,
1011 attention_mechanism: AttentionImplementation,
1012 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1013 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1014
1015 Ok(Box::new(models::mixtral::Model::new(
1016 &cfg,
1017 vb,
1018 self.is_gptx(config)?,
1019 normal_loading_metadata,
1020 attention_mechanism,
1021 )?))
1022 }
1023 fn load_xlora(
1024 &self,
1025 config: &str,
1026 vb: ShardedVarBuilder,
1027 lora_config: &[((String, String), LoraConfig)],
1028 xlora_config: Option<XLoraConfig>,
1029 xlora_ordering: Ordering,
1030 normal_loading_metadata: NormalLoadingMetadata,
1031 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1032 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1033 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1034
1035 Ok(Box::new(xlora_models::XLoraMixtral::new(
1036 &cfg,
1037 vb,
1038 lora_config,
1039 xlora_config,
1040 xlora_ordering,
1041 self.is_gptx(config)?,
1042 normal_loading_metadata,
1043 preload_adapters,
1044 )?))
1045 }
1046 fn is_gptx(&self, _: &str) -> Result<bool> {
1047 Ok(true)
1048 }
1049 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1050 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1051
1052 Ok(Box::new(cfg))
1053 }
1054}
1055
1056impl IsqModelLoader for MixtralLoader {
1057 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1058 Ok(vec![
1059 Regex::new(r"lm_head\.(weight|bias)$")?,
1060 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1062 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1063 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1064 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1065 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1067 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1068 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1069 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1070 ])
1071 }
1072 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1073 self.isq_layer_regexes(config)
1074 }
1075}
1076
1077impl DeviceMappedModelLoader for MixtralLoader {
1078 fn mapped_max_act_size_elems(
1079 &self,
1080 config: &str,
1081 params: &AutoDeviceMapParams,
1082 ) -> Result<usize> {
1083 let AutoDeviceMapParams::Text {
1084 max_seq_len,
1085 max_batch_size,
1086 } = params
1087 else {
1088 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1089 };
1090
1091 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1092
1093 Ok(
1094 max_batch_size
1095 * cfg.num_attention_heads
1096 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1097 )
1098 }
1099 fn non_mapped_max_act_size_elems(
1100 &self,
1101 _config: &str,
1102 _params: &AutoDeviceMapParams,
1103 ) -> Result<usize> {
1104 Ok(0)
1105 }
1106
1107 fn non_mapped_size_in_bytes(
1108 &self,
1109 config: &str,
1110 dtype: DType,
1111 weight_pack_factor: usize,
1112 _matformer_config: Option<&MatformerSliceConfig>,
1113 ) -> Result<usize> {
1114 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1115
1116 let elems = {
1117 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1118 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1120 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1121 } else {
1122 0
1123 };
1124 let norm = cfg.hidden_size;
1125 embed_tokens + lm_head + norm
1126 };
1127 Ok(elems * dtype.size_in_bytes())
1128 }
1129
1130 fn layer_sizes_in_bytes(
1131 &self,
1132 config: &str,
1133 dtype: DType,
1134 weight_pack_factor: usize,
1135 _matformer_config: Option<&MatformerSliceConfig>,
1136 ) -> Result<Vec<usize>> {
1137 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1138
1139 let per_layer_elems = {
1140 let input_layernorm = cfg.hidden_size;
1141 let post_attention_layernorm = cfg.hidden_size;
1142
1143 let size_in = cfg.hidden_size;
1144 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1145 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1146 let q_proj = size_in * size_q / weight_pack_factor;
1147 let k_proj = size_in * size_kv / weight_pack_factor;
1148 let v_proj = size_in * size_kv / weight_pack_factor;
1149 let o_proj = size_q * size_in / weight_pack_factor;
1150
1151 let moe_block = {
1152 let gate = cfg.hidden_size * cfg.num_local_experts;
1153 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1155 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1156 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1157 gate + cfg.num_local_experts * w1
1158 + cfg.num_local_experts * w2
1159 + cfg.num_local_experts * w3
1160 };
1161
1162 input_layernorm
1163 + post_attention_layernorm
1164 + q_proj
1165 + k_proj
1166 + v_proj
1167 + o_proj
1168 + moe_block
1169 };
1170 Ok(vec![
1171 per_layer_elems * dtype.size_in_bytes();
1172 cfg.num_hidden_layers
1173 ])
1174 }
1175
1176 fn num_layers(&self, config: &str) -> Result<usize> {
1177 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1178
1179 Ok(cfg.num_hidden_layers)
1180 }
1181
1182 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1183 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1184
1185 let cfg = ModelConfigMetadata {
1186 max_seq_len: cfg.max_position_embeddings,
1187 num_layers: cfg.num_hidden_layers,
1188 hidden_size: cfg.hidden_size,
1189 num_kv_heads: cfg.num_key_value_heads,
1190 num_attn_heads: cfg.num_attention_heads,
1191 sliding_window: cfg.sliding_window,
1192 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1193 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1194 };
1195
1196 Ok(Box::new(cfg))
1197 }
1198}
1199
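/// [`NormalModelLoader`] for Phi-2 models.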
1200pub struct Phi2Loader;
1206
1207impl NormalModelLoader for Phi2Loader {
1208 fn load(
1209 &self,
1210 config: &str,
1211 vb: ShardedVarBuilder,
1212 normal_loading_metadata: NormalLoadingMetadata,
1213 attention_mechanism: AttentionImplementation,
1214 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1215 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1216
1217 Ok(Box::new(models::phi2::Model::new(
1218 &cfg,
1219 vb,
1220 self.is_gptx(config)?,
1221 normal_loading_metadata,
1222 attention_mechanism,
1223 )?))
1224 }
1225 fn load_xlora(
1226 &self,
1227 config: &str,
1228 vb: ShardedVarBuilder,
1229 lora_config: &[((String, String), LoraConfig)],
1230 xlora_config: Option<XLoraConfig>,
1231 xlora_ordering: Ordering,
1232 normal_loading_metadata: NormalLoadingMetadata,
1233 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1234 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1235 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1236
1237 Ok(Box::new(xlora_models::XLoraPhi2::new(
1238 &cfg,
1239 vb,
1240 lora_config,
1241 xlora_config,
1242 xlora_ordering,
1243 self.is_gptx(config)?,
1244 normal_loading_metadata,
1245 preload_adapters,
1246 )?))
1247 }
1248 fn is_gptx(&self, _: &str) -> Result<bool> {
1249 Ok(true)
1250 }
1251 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1252 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1253
1254 Ok(Box::new(cfg))
1255 }
1256}
1257
1258impl IsqModelLoader for Phi2Loader {
1259 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1260 Ok(vec![
1261 Regex::new(r"lm_head\.(weight|bias)$")?,
1262 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1264 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1265 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1266 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1267 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1269 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1270 ])
1271 }
1272 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1273 self.isq_layer_regexes(config)
1274 }
1275}
1276
1277impl DeviceMappedModelLoader for Phi2Loader {
1278 fn mapped_max_act_size_elems(
1279 &self,
1280 config: &str,
1281 params: &AutoDeviceMapParams,
1282 ) -> Result<usize> {
1283 let AutoDeviceMapParams::Text {
1284 max_seq_len,
1285 max_batch_size,
1286 } = params
1287 else {
1288 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1289 };
1290
1291 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1292
1293 Ok(
1294 max_batch_size
1295 * cfg.num_attention_heads
1296 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1297 )
1298 }
1299 fn non_mapped_max_act_size_elems(
1300 &self,
1301 _config: &str,
1302 _params: &AutoDeviceMapParams,
1303 ) -> Result<usize> {
1304 Ok(0)
1305 }
1306
1307 fn non_mapped_size_in_bytes(
1308 &self,
1309 config: &str,
1310 dtype: DType,
1311 weight_pack_factor: usize,
1312 _matformer_config: Option<&MatformerSliceConfig>,
1313 ) -> Result<usize> {
1314 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1315
1316 let elems = {
1317 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1318 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1320 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1321 } else {
1322 0
1323 };
1324 let norm = cfg.hidden_size;
1325 embed_tokens + lm_head + norm
1326 };
1327 Ok(elems * dtype.size_in_bytes())
1328 }
1329
1330 fn layer_sizes_in_bytes(
1331 &self,
1332 config: &str,
1333 dtype: DType,
1334 weight_pack_factor: usize,
1335 _matformer_config: Option<&MatformerSliceConfig>,
1336 ) -> Result<Vec<usize>> {
1337 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1338
1339 let per_layer_elems = {
1340 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1341
1342 let size_in = cfg.hidden_size;
1343 let size_q = cfg.head_dim() * cfg.num_attention_heads;
1344 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1345 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1346 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1347 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1348 let o_proj = size_q * size_in / weight_pack_factor + size_in;
1349 let (q_norm, k_norm) = if cfg.qk_layernorm {
1350 (cfg.head_dim(), cfg.head_dim())
1351 } else {
1352 (0, 0)
1353 };
1354
1355 let h_size = cfg.hidden_size;
1356 let i_size = cfg.intermediate_size;
1357 let fc1 = h_size * i_size / weight_pack_factor;
1358 let fc2 = h_size * i_size / weight_pack_factor;
1359
1360 input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1361 };
1362 Ok(vec![
1363 per_layer_elems * dtype.size_in_bytes();
1364 cfg.num_hidden_layers
1365 ])
1366 }
1367
1368 fn num_layers(&self, config: &str) -> Result<usize> {
1369 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1370
1371 Ok(cfg.num_hidden_layers)
1372 }
1373
1374 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1375 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1376
1377 let cfg = ModelConfigMetadata {
1378 max_seq_len: cfg.max_position_embeddings,
1379 num_layers: cfg.num_hidden_layers,
1380 hidden_size: cfg.hidden_size,
1381 num_kv_heads: cfg.num_key_value_heads(),
1382 num_attn_heads: cfg.num_attention_heads,
1383 sliding_window: None,
1384 k_head_dim: cfg.head_dim(),
1385 v_head_dim: cfg.head_dim(),
1386 };
1387
1388 Ok(Box::new(cfg))
1389 }
1390}
1391
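/// [`NormalModelLoader`] for Phi-3 models.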
1392pub struct Phi3Loader;
1398
1399impl NormalModelLoader for Phi3Loader {
1400 fn load(
1401 &self,
1402 config: &str,
1403 vb: ShardedVarBuilder,
1404 normal_loading_metadata: NormalLoadingMetadata,
1405 attention_mechanism: AttentionImplementation,
1406 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1407 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1408
1409 Ok(Box::new(models::phi3::Model::new(
1410 &cfg,
1411 vb,
1412 self.is_gptx(config)?,
1413 normal_loading_metadata,
1414 attention_mechanism,
1415 )?))
1416 }
1417 fn load_xlora(
1418 &self,
1419 config: &str,
1420 vb: ShardedVarBuilder,
1421 lora_config: &[((String, String), LoraConfig)],
1422 xlora_config: Option<XLoraConfig>,
1423 xlora_ordering: Ordering,
1424 normal_loading_metadata: NormalLoadingMetadata,
1425 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1426 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1427 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1428
1429 Ok(Box::new(xlora_models::XLoraPhi3::new(
1430 &cfg,
1431 vb,
1432 lora_config,
1433 xlora_config,
1434 xlora_ordering,
1435 self.is_gptx(config)?,
1436 normal_loading_metadata,
1437 preload_adapters,
1438 )?))
1439 }
1440 fn is_gptx(&self, _: &str) -> Result<bool> {
1441 Ok(true)
1442 }
1443 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1444 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1445
1446 Ok(Box::new(cfg))
1447 }
1448}
1449
1450impl IsqModelLoader for Phi3Loader {
1451 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1452 Ok(vec![
1453 Regex::new(r"lm_head\.(weight|bias)$")?,
1454 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1456 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1457 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1459 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1460 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1461 ])
1462 }
1463 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1464 self.isq_layer_regexes(config)
1465 }
1466}
1467
1468impl DeviceMappedModelLoader for Phi3Loader {
1469 fn mapped_max_act_size_elems(
1470 &self,
1471 config: &str,
1472 params: &AutoDeviceMapParams,
1473 ) -> Result<usize> {
1474 let AutoDeviceMapParams::Text {
1475 max_seq_len,
1476 max_batch_size,
1477 } = params
1478 else {
1479 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1480 };
1481
1482 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1483
1484 Ok(
1485 max_batch_size
1486 * cfg.num_attention_heads
1487 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1488 )
1489 }
1490 fn non_mapped_max_act_size_elems(
1491 &self,
1492 _config: &str,
1493 _params: &AutoDeviceMapParams,
1494 ) -> Result<usize> {
1495 Ok(0)
1496 }
1497
1498 fn non_mapped_size_in_bytes(
1499 &self,
1500 config: &str,
1501 dtype: DType,
1502 weight_pack_factor: usize,
1503 _matformer_config: Option<&MatformerSliceConfig>,
1504 ) -> Result<usize> {
1505 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1506
1507 let elems = {
1508 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1509 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1511 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1512 } else {
1513 0
1514 };
1515 let norm = cfg.hidden_size;
1516 embed_tokens + lm_head + norm
1517 };
1518 Ok(elems * dtype.size_in_bytes())
1519 }
1520
1521 fn layer_sizes_in_bytes(
1522 &self,
1523 config: &str,
1524 dtype: DType,
1525 weight_pack_factor: usize,
1526 _matformer_config: Option<&MatformerSliceConfig>,
1527 ) -> Result<Vec<usize>> {
1528 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1529
1530 let per_layer_elems = {
1531 let input_layernorm = cfg.hidden_size;
1532 let post_attention_layernorm = cfg.hidden_size;
1533
1534 let size_in = cfg.hidden_size;
1535 let head_dim = cfg.head_dim();
1536 let op_size =
1537 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1538 let qkv_proj = size_in * op_size / weight_pack_factor;
1539 let o_proj =
1540 (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1541
1542 let h_size = cfg.hidden_size;
1543 let i_size = cfg.intermediate_size;
1544 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1545 let down_proj = h_size * i_size / weight_pack_factor;
1546
1547 input_layernorm
1548 + post_attention_layernorm
1549 + qkv_proj
1550 + o_proj
1551 + gate_up_proj
1552 + down_proj
1553 };
1554 Ok(vec![
1555 per_layer_elems * dtype.size_in_bytes();
1556 cfg.num_hidden_layers
1557 ])
1558 }
1559
1560 fn num_layers(&self, config: &str) -> Result<usize> {
1561 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1562
1563 Ok(cfg.num_hidden_layers)
1564 }
1565
1566 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1567 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1568
1569 let cfg = ModelConfigMetadata {
1570 max_seq_len: cfg.max_position_embeddings,
1571 num_layers: cfg.num_hidden_layers,
1572 hidden_size: cfg.hidden_size,
1573 num_kv_heads: cfg.num_key_value_heads,
1574 num_attn_heads: cfg.num_attention_heads,
1575 sliding_window: cfg.sliding_window,
1576 k_head_dim: cfg.head_dim(),
1577 v_head_dim: cfg.head_dim(),
1578 };
1579
1580 Ok(Box::new(cfg))
1581 }
1582}
1583
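/// [`NormalModelLoader`] for Qwen2 models. X-LoRA loading is not yet implemented for
/// this loader (`load_xlora` is `todo!()`).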
1584pub struct Qwen2Loader;
1590
1591impl NormalModelLoader for Qwen2Loader {
1592 fn load(
1593 &self,
1594 config: &str,
1595 vb: ShardedVarBuilder,
1596 normal_loading_metadata: NormalLoadingMetadata,
1597 attention_mechanism: AttentionImplementation,
1598 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1599 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1600
1601 Ok(Box::new(models::qwen2::Model::new(
1602 &cfg,
1603 vb,
1604 self.is_gptx(config)?,
1605 normal_loading_metadata,
1606 attention_mechanism,
1607 )?))
1608 }
1609 fn load_xlora(
1610 &self,
1611 _config: &str,
1612 _vb: ShardedVarBuilder,
1613 _lora_config: &[((String, String), LoraConfig)],
1614 _xlora_config: Option<XLoraConfig>,
1615 _xlora_ordering: Ordering,
1616 _normal_loading_metadata: NormalLoadingMetadata,
1617 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1618 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1619 todo!()
1620 }
1621 fn is_gptx(&self, _: &str) -> Result<bool> {
1622 Ok(true)
1623 }
1624 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1625 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1626
1627 Ok(Box::new(cfg))
1628 }
1629}
1630
1631impl IsqModelLoader for Qwen2Loader {
1632 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1633 Ok(vec![
1634 Regex::new(r"lm_head\.(weight|bias)$")?,
1635 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1637 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1638 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1639 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1640 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1642 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1643 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1644 ])
1645 }
1646 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1647 self.isq_layer_regexes(config)
1648 }
1649}
1650
1651impl DeviceMappedModelLoader for Qwen2Loader {
1652 fn mapped_max_act_size_elems(
1653 &self,
1654 config: &str,
1655 params: &AutoDeviceMapParams,
1656 ) -> Result<usize> {
1657 let AutoDeviceMapParams::Text {
1658 max_seq_len,
1659 max_batch_size,
1660 } = params
1661 else {
1662 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1663 };
1664
1665 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1666
1667 Ok(
1668 max_batch_size
1669 * cfg.num_attention_heads
1670 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1671 )
1672 }
1673 fn non_mapped_max_act_size_elems(
1674 &self,
1675 _config: &str,
1676 _params: &AutoDeviceMapParams,
1677 ) -> Result<usize> {
1678 Ok(0)
1679 }
1680
1681 fn non_mapped_size_in_bytes(
1682 &self,
1683 config: &str,
1684 dtype: DType,
1685 weight_pack_factor: usize,
1686 _matformer_config: Option<&MatformerSliceConfig>,
1687 ) -> Result<usize> {
1688 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1689
1690 let elems = {
1691 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1692 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1694 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1695 } else {
1696 0
1697 };
1698 let norm = cfg.hidden_size;
1699 embed_tokens + lm_head + norm
1700 };
1701 Ok(elems * dtype.size_in_bytes())
1702 }
1703
1704 fn layer_sizes_in_bytes(
1705 &self,
1706 config: &str,
1707 dtype: DType,
1708 weight_pack_factor: usize,
1709 _matformer_config: Option<&MatformerSliceConfig>,
1710 ) -> Result<Vec<usize>> {
1711 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1712
1713 let per_layer_elems = {
1714 let input_layernorm = cfg.hidden_size;
1715 let post_attention_layernorm = cfg.hidden_size;
1716
1717 let size_in = cfg.hidden_size;
1718 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1719 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1720 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1721 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1722 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1723 let o_proj = size_q * size_in / weight_pack_factor;
1724
1725 let h_size = cfg.hidden_size;
1726 let i_size = cfg.intermediate_size;
1727 let gate_proj = h_size * i_size / weight_pack_factor;
1728 let up_proj = h_size * i_size / weight_pack_factor;
1729 let down_proj = i_size * h_size / weight_pack_factor;
1730
1731 input_layernorm
1732 + post_attention_layernorm
1733 + q_proj
1734 + k_proj
1735 + v_proj
1736 + o_proj
1737 + gate_proj
1738 + up_proj
1739 + down_proj
1740 };
1741 Ok(vec![
1742 per_layer_elems * dtype.size_in_bytes();
1743 cfg.num_hidden_layers
1744 ])
1745 }
1746
1747 fn num_layers(&self, config: &str) -> Result<usize> {
1748 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1749
1750 Ok(cfg.num_hidden_layers)
1751 }
1752
1753 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1754 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1755
1756 let cfg = ModelConfigMetadata {
1757 max_seq_len: cfg.max_position_embeddings,
1758 num_layers: cfg.num_hidden_layers,
1759 hidden_size: cfg.hidden_size,
1760 num_kv_heads: cfg.num_key_value_heads,
1761 num_attn_heads: cfg.num_attention_heads,
1762 sliding_window: cfg.sliding_window,
1763 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1764 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1765 };
1766
1767 Ok(Box::new(cfg))
1768 }
1769}
1770
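/// [`NormalModelLoader`] for Gemma 2 models.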
1771pub struct Gemma2Loader;
1777
1778impl NormalModelLoader for Gemma2Loader {
1779 fn load(
1780 &self,
1781 config: &str,
1782 vb: ShardedVarBuilder,
1783 normal_loading_metadata: NormalLoadingMetadata,
1784 attention_mechanism: AttentionImplementation,
1785 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1786 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1787
1788 Ok(Box::new(models::gemma2::Model::new(
1789 &cfg,
1790 vb,
1791 self.is_gptx(config)?,
1792 normal_loading_metadata,
1793 attention_mechanism,
1794 )?))
1795 }
1796 fn load_xlora(
1797 &self,
1798 config: &str,
1799 vb: ShardedVarBuilder,
1800 lora_config: &[((String, String), LoraConfig)],
1801 xlora_config: Option<XLoraConfig>,
1802 xlora_ordering: Ordering,
1803 normal_loading_metadata: NormalLoadingMetadata,
1804 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1805 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1806 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1807
1808 Ok(Box::new(xlora_models::XLoraGemma2::new(
1809 &cfg,
1810 vb,
1811 lora_config,
1812 xlora_config,
1813 xlora_ordering,
1814 self.is_gptx(config)?,
1815 normal_loading_metadata,
1816 preload_adapters,
1817 )?))
1818 }
1819 fn is_gptx(&self, _: &str) -> Result<bool> {
1820 Ok(true)
1821 }
1822 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1823 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1824
1825 Ok(Box::new(cfg))
1826 }
1827}
1828
1829impl IsqModelLoader for Gemma2Loader {
1830 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1831 Ok(vec![
1832 Regex::new(r"lm_head\.(weight|bias)$")?,
1833 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1835 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1836 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1837 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1838 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1840 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1841 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1842 ])
1843 }
1844 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1845 self.isq_layer_regexes(config)
1846 }
1847}
1848
1849impl DeviceMappedModelLoader for Gemma2Loader {
1850 fn mapped_max_act_size_elems(
1851 &self,
1852 config: &str,
1853 params: &AutoDeviceMapParams,
1854 ) -> Result<usize> {
1855 let AutoDeviceMapParams::Text {
1856 max_seq_len,
1857 max_batch_size,
1858 } = params
1859 else {
1860 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1861 };
1862
1863 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1864
1865 Ok(
1866 max_batch_size
1867 * cfg.num_attention_heads
1868 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1869 )
1870 }
1871 fn non_mapped_max_act_size_elems(
1872 &self,
1873 _config: &str,
1874 _params: &AutoDeviceMapParams,
1875 ) -> Result<usize> {
1876 Ok(0)
1877 }
1878
1879 fn non_mapped_size_in_bytes(
1880 &self,
1881 config: &str,
1882 dtype: DType,
1883 weight_pack_factor: usize,
1884 _matformer_config: Option<&MatformerSliceConfig>,
1885 ) -> Result<usize> {
1886 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1887
1888 let elems = {
1889 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1890 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1892 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1893 } else {
1894 0
1895 };
1896 let norm = cfg.hidden_size;
1897 embed_tokens + lm_head + norm
1898 };
1899 Ok(elems * dtype.size_in_bytes())
1900 }
1901
1902 fn layer_sizes_in_bytes(
1903 &self,
1904 config: &str,
1905 dtype: DType,
1906 weight_pack_factor: usize,
1907 _matformer_config: Option<&MatformerSliceConfig>,
1908 ) -> Result<Vec<usize>> {
1909 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1910
1911 let per_layer_elems = {
1912 let input_layernorm = cfg.hidden_size;
1913 let post_attention_layernorm = cfg.hidden_size;
1914
1915 let size_in = cfg.hidden_size;
1916 let size_q = cfg.head_dim * cfg.num_attention_heads;
1917 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1918 let q_proj =
1919 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1920 let k_proj =
1921 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1922 let v_proj =
1923 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1924 let o_proj =
1925 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1926
1927 let h_size = cfg.hidden_size;
1928 let i_size = cfg.intermediate_size;
1929 let gate_proj = h_size * i_size / weight_pack_factor;
1930 let up_proj = h_size * i_size / weight_pack_factor;
1931 let down_proj = i_size * h_size / weight_pack_factor;
1932
1933 input_layernorm
1934 + post_attention_layernorm
1935 + q_proj
1936 + k_proj
1937 + v_proj
1938 + o_proj
1939 + gate_proj
1940 + up_proj
1941 + down_proj
1942 };
1943 Ok(vec![
1944 per_layer_elems * dtype.size_in_bytes();
1945 cfg.num_hidden_layers
1946 ])
1947 }
1948
1949 fn num_layers(&self, config: &str) -> Result<usize> {
1950 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1951
1952 Ok(cfg.num_hidden_layers)
1953 }
1954 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1955 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1956
1957 let cfg = ModelConfigMetadata {
1958 max_seq_len: cfg.max_position_embeddings,
1959 num_layers: cfg.num_hidden_layers,
1960 hidden_size: cfg.hidden_size,
1961 num_kv_heads: cfg.num_key_value_heads,
1962 num_attn_heads: cfg.num_attention_heads,
1963            sliding_window: None,
1964            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1965 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1966 };
1967
1968 Ok(Box::new(cfg))
1969 }
1970}
1971
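/// Loader for Starcoder2 text models: deserializes the JSON config, builds the model or its
/// X-LoRA variant, and reports ISQ regexes plus device-mapping size estimates.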
1972pub struct Starcoder2Loader;
1978
1979impl NormalModelLoader for Starcoder2Loader {
1980 fn load(
1981 &self,
1982 config: &str,
1983 vb: ShardedVarBuilder,
1984 normal_loading_metadata: NormalLoadingMetadata,
1985 attention_mechanism: AttentionImplementation,
1986 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1987 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1988
1989 Ok(Box::new(models::starcoder2::Model::new(
1990 &cfg,
1991 vb,
1992 self.is_gptx(config)?,
1993 normal_loading_metadata,
1994 attention_mechanism,
1995 )?))
1996 }
1997 fn load_xlora(
1998 &self,
1999 config: &str,
2000 vb: ShardedVarBuilder,
2001 lora_config: &[((String, String), LoraConfig)],
2002 xlora_config: Option<XLoraConfig>,
2003 xlora_ordering: Ordering,
2004 normal_loading_metadata: NormalLoadingMetadata,
2005 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2006 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2007 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2008
2009 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2010 &cfg,
2011 vb,
2012 lora_config,
2013 xlora_config,
2014 xlora_ordering,
2015 self.is_gptx(config)?,
2016 normal_loading_metadata,
2017 preload_adapters,
2018 )?))
2019 }
2020 fn is_gptx(&self, _: &str) -> Result<bool> {
2021 Ok(true)
2022 }
2023 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2024 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2025
2026 Ok(Box::new(cfg))
2027 }
2028}
2029
2030impl IsqModelLoader for Starcoder2Loader {
2031 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2032 Ok(vec![
2033 Regex::new(r"lm_head\.(weight|bias)$")?,
2034 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2036 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2037 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2038 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2039 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2041 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2042 ])
2043 }
2044 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2045 self.isq_layer_regexes(config)
2046 }
2047}
2048
2049impl DeviceMappedModelLoader for Starcoder2Loader {
2050 fn mapped_max_act_size_elems(
2051 &self,
2052 config: &str,
2053 params: &AutoDeviceMapParams,
2054 ) -> Result<usize> {
2055 let AutoDeviceMapParams::Text {
2056 max_seq_len,
2057 max_batch_size,
2058 } = params
2059 else {
2060 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2061 };
2062
2063 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2064
2065 Ok(
2066 max_batch_size
2067 * cfg.num_attention_heads
2068 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2069 )
2070 }
2071 fn non_mapped_max_act_size_elems(
2072 &self,
2073 _config: &str,
2074 _params: &AutoDeviceMapParams,
2075 ) -> Result<usize> {
2076 Ok(0)
2077 }
2078
2079 fn non_mapped_size_in_bytes(
2080 &self,
2081 config: &str,
2082 dtype: DType,
2083 weight_pack_factor: usize,
2084 _matformer_config: Option<&MatformerSliceConfig>,
2085 ) -> Result<usize> {
2086 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2087
2088 let elems = {
2089 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2090 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2092 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2093 } else {
2094 0
2095 };
2096 let norm = cfg.hidden_size + cfg.hidden_size;
2097 embed_tokens + lm_head + norm
2098 };
2099 Ok(elems * dtype.size_in_bytes())
2100 }
2101
2102 fn layer_sizes_in_bytes(
2103 &self,
2104 config: &str,
2105 dtype: DType,
2106 weight_pack_factor: usize,
2107 _matformer_config: Option<&MatformerSliceConfig>,
2108 ) -> Result<Vec<usize>> {
2109 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2110
2111 let per_layer_elems = {
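            // Starcoder2 uses LayerNorm, so each norm contributes weight plus bias (2 * hidden_size).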
2112 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2113 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2114
2115 let size_in = cfg.hidden_size;
2116 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2117 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2118 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2119 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2120 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2121 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2122
2123 let h_size = cfg.hidden_size;
2124 let i_size = cfg.intermediate_size;
2125 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2126 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2127
2128 input_layernorm
2129 + post_attention_layernorm
2130 + q_proj
2131 + k_proj
2132 + v_proj
2133 + o_proj
2134 + fc1
2135 + fc2
2136 };
2137 Ok(vec![
2138 per_layer_elems * dtype.size_in_bytes();
2139 cfg.num_hidden_layers
2140 ])
2141 }
2142
2143 fn num_layers(&self, config: &str) -> Result<usize> {
2144 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2145
2146 Ok(cfg.num_hidden_layers)
2147 }
2148
2149 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2150 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2151
2152 let cfg = ModelConfigMetadata {
2153 max_seq_len: cfg.max_position_embeddings,
2154 num_layers: cfg.num_hidden_layers,
2155 hidden_size: cfg.hidden_size,
2156 num_kv_heads: cfg.num_key_value_heads,
2157 num_attn_heads: cfg.num_attention_heads,
2158 sliding_window: cfg.sliding_window,
2159 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2160 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2161 };
2162
2163 Ok(Box::new(cfg))
2164 }
2165}
2166
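/// Loader for Phi 3.5 MoE text models; its ISQ regexes and size estimates cover the
/// block_sparse_moe expert weights.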
2167pub struct Phi3_5MoELoader;
2173
2174impl NormalModelLoader for Phi3_5MoELoader {
2175 fn load(
2176 &self,
2177 config: &str,
2178 vb: ShardedVarBuilder,
2179 normal_loading_metadata: NormalLoadingMetadata,
2180 attention_mechanism: AttentionImplementation,
2181 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2182 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2183
2184 Ok(Box::new(models::phi3_5_moe::Model::new(
2185 &cfg,
2186 vb,
2187 self.is_gptx(config)?,
2188 normal_loading_metadata,
2189 attention_mechanism,
2190 )?))
2191 }
2192 fn load_xlora(
2193 &self,
2194 config: &str,
2195 vb: ShardedVarBuilder,
2196 lora_config: &[((String, String), LoraConfig)],
2197 xlora_config: Option<XLoraConfig>,
2198 xlora_ordering: Ordering,
2199 normal_loading_metadata: NormalLoadingMetadata,
2200 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2201 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2202 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2203
2204 Ok(Box::new(xlora_models::XLoraPhi3::new(
2205 &cfg,
2206 vb,
2207 lora_config,
2208 xlora_config,
2209 xlora_ordering,
2210 self.is_gptx(config)?,
2211 normal_loading_metadata,
2212 preload_adapters,
2213 )?))
2214 }
2215 fn is_gptx(&self, _: &str) -> Result<bool> {
2216 Ok(true)
2217 }
2218 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2219 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2220
2221 Ok(Box::new(cfg))
2222 }
2223}
2224
2225impl IsqModelLoader for Phi3_5MoELoader {
2226 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2227 Ok(vec![
2228 Regex::new(r"lm_head\.(weight|bias)$")?,
2229 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2231 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2232 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2233 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2234 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2236 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2237 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2238 ])
2239 }
2240 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2241 self.isq_layer_regexes(config)
2242 }
2243
2244 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2245 Ok(vec![
2246 Regex::new(r"lm_head\.(weight|bias)$")?,
2247 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2249 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2250 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2251 ])
2252 }
2253 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2254 self.isq_layer_regexes_moqe(config)
2255 }
2256}
2257
2258impl DeviceMappedModelLoader for Phi3_5MoELoader {
2259 fn mapped_max_act_size_elems(
2260 &self,
2261 config: &str,
2262 params: &AutoDeviceMapParams,
2263 ) -> Result<usize> {
2264 let AutoDeviceMapParams::Text {
2265 max_seq_len,
2266 max_batch_size,
2267 } = params
2268 else {
2269 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2270 };
2271
2272 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2273
2274 Ok(
2275 max_batch_size
2276 * cfg.num_attention_heads
2277 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2278 )
2279 }
2280 fn non_mapped_max_act_size_elems(
2281 &self,
2282 _config: &str,
2283 _params: &AutoDeviceMapParams,
2284 ) -> Result<usize> {
2285 Ok(0)
2286 }
2287
2288 fn non_mapped_size_in_bytes(
2289 &self,
2290 config: &str,
2291 dtype: DType,
2292 weight_pack_factor: usize,
2293 _matformer_config: Option<&MatformerSliceConfig>,
2294 ) -> Result<usize> {
2295 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2296
2297 let elems = {
2298 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2299 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2301 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2302 } else {
2303 0
2304 };
2305 let norm = cfg.hidden_size;
2306 embed_tokens + lm_head + norm
2307 };
2308 Ok(elems * dtype.size_in_bytes())
2309 }
2310
2311 fn layer_sizes_in_bytes(
2312 &self,
2313 config: &str,
2314 dtype: DType,
2315 weight_pack_factor: usize,
2316 _matformer_config: Option<&MatformerSliceConfig>,
2317 ) -> Result<Vec<usize>> {
2318 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2319
2320 let per_layer_elems = {
2321 let input_layernorm = cfg.hidden_size;
2322 let post_attention_layernorm = cfg.hidden_size;
2323
2324 let size_in = cfg.hidden_size;
2325 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2326 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2327 let q_proj =
2328 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2329 let k_proj =
2330 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2331 let v_proj =
2332 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2333 let o_proj =
2334 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2335
2336 let moe_block = {
2337 let gate = cfg.hidden_size * cfg.num_local_experts;
2338 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2340 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2341 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2342 gate + cfg.num_local_experts * w1
2343 + cfg.num_local_experts * w2
2344 + cfg.num_local_experts * w3
2345 };
2346
2347 input_layernorm
2348 + post_attention_layernorm
2349 + q_proj
2350 + k_proj
2351 + v_proj
2352 + o_proj
2353 + moe_block
2354 };
2355 Ok(vec![
2356 per_layer_elems * dtype.size_in_bytes();
2357 cfg.num_hidden_layers
2358 ])
2359 }
2360
2361 fn num_layers(&self, config: &str) -> Result<usize> {
2362 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2363
2364 Ok(cfg.num_hidden_layers)
2365 }
2366
2367 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2368 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2369
2370 let cfg = ModelConfigMetadata {
2371 max_seq_len: cfg.max_position_embeddings,
2372 num_layers: cfg.num_hidden_layers,
2373 hidden_size: cfg.hidden_size,
2374 num_kv_heads: cfg.num_key_value_heads,
2375 num_attn_heads: cfg.num_attention_heads,
2376 sliding_window: cfg.sliding_window,
2377 k_head_dim: cfg.head_dim(),
2378 v_head_dim: cfg.head_dim(),
2379 };
2380
2381 Ok(Box::new(cfg))
2382 }
2383}
2384
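/// Loader for DeepSeek V2 text models, accounting for the low-rank (MLA-style) attention
/// projections and MoE layers. X-LoRA loading is not implemented.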
2385pub struct DeepSeekV2Loader;
2389
2390impl NormalModelLoader for DeepSeekV2Loader {
2391 fn load(
2392 &self,
2393 config: &str,
2394 vb: ShardedVarBuilder,
2395 normal_loading_metadata: NormalLoadingMetadata,
2396 attention_mechanism: AttentionImplementation,
2397 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2398 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2399
2400 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2401 &cfg,
2402 vb,
2403 self.is_gptx(config)?,
2404 normal_loading_metadata,
2405 attention_mechanism,
2406 )?))
2407 }
2408 fn load_xlora(
2409 &self,
2410 _config: &str,
2411 _vb: ShardedVarBuilder,
2412 _lora_config: &[((String, String), LoraConfig)],
2413 _xlora_config: Option<XLoraConfig>,
2414 _xlora_ordering: Ordering,
2415 _normal_loading_metadata: NormalLoadingMetadata,
2416 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2417 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2418 todo!()
2419 }
2420 fn is_gptx(&self, _: &str) -> Result<bool> {
2421 Ok(true)
2422 }
2423 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2424 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2425 Ok(Box::new(cfg))
2426 }
2427}
2428
2429impl IsqModelLoader for DeepSeekV2Loader {
2430 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2431 let mut data = vec![
2432 Regex::new(r"lm_head\.(weight|bias)$")?,
2433 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2435 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2436 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2437 ];
2438 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2439 if cfg.q_lora_rank.is_some() {
2440 data.extend(vec![
2441 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2442 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2443 ]);
2444 } else {
2445 data.push(Regex::new(
2446 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2447 )?);
2448 }
2449 for layer_idx in 0..cfg.num_hidden_layers {
2450 if cfg.n_routed_experts.is_some()
2451 && layer_idx >= cfg.first_k_dense_replace
2452 && layer_idx % cfg.moe_layer_freq == 0
2453 {
2454 for i in 0..cfg.n_routed_experts.unwrap() {
2455 data.extend(vec![
2456 Regex::new(&format!(
2457 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2458 ))?,
2459 Regex::new(&format!(
2460 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2461 ))?,
2462 Regex::new(&format!(
2463 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2464 ))?,
2465 ]);
2466 }
2467 if cfg.n_shared_experts.is_some() {
2468 data.extend(vec![
2469 Regex::new(&format!(
2470 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2471 ))?,
2472 Regex::new(&format!(
2473 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2474 ))?,
2475 Regex::new(&format!(
2476 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2477 ))?,
2478 ]);
2479 }
2480 } else {
2481 data.extend(vec![
2482 Regex::new(&format!(
2483 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2484 ))?,
2485                Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2486 Regex::new(&format!(
2487 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2488 ))?,
2489 ]);
2490 };
2491 }
2492 Ok(data)
2493 }
2494 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2495 self.isq_layer_regexes(config)
2496 }
2497
2498 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2499 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2500 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2501 for layer_idx in 0..cfg.num_hidden_layers {
2502 if cfg.n_routed_experts.is_some()
2503 && layer_idx >= cfg.first_k_dense_replace
2504 && layer_idx % cfg.moe_layer_freq == 0
2505 {
2506 for i in 0..cfg.n_routed_experts.unwrap() {
2507 data.extend(vec![
2508 Regex::new(&format!(
2509 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2510 ))?,
2511 Regex::new(&format!(
2512 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2513 ))?,
2514 Regex::new(&format!(
2515 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2516 ))?,
2517 ]);
2518 }
2519 if cfg.n_shared_experts.is_some() {
2520 data.extend(vec![
2521 Regex::new(&format!(
2522 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2523 ))?,
2524 Regex::new(&format!(
2525 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2526 ))?,
2527 Regex::new(&format!(
2528 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2529 ))?,
2530 ]);
2531 }
2532 } else {
2533 data.extend(vec![
2534 Regex::new(&format!(
2535 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2536 ))?,
2537                Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2538 Regex::new(&format!(
2539 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2540 ))?,
2541 ]);
2542 };
2543 }
2544 Ok(data)
2545 }
2546 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2547 self.isq_layer_regexes_moqe(config)
2548 }
2549}
2550
2551impl DeviceMappedModelLoader for DeepSeekV2Loader {
2552 fn mapped_max_act_size_elems(
2553 &self,
2554 config: &str,
2555 params: &AutoDeviceMapParams,
2556 ) -> Result<usize> {
2557 let AutoDeviceMapParams::Text {
2558 max_seq_len,
2559 max_batch_size,
2560 } = params
2561 else {
2562 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2563 };
2564
2565 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2566
2567 Ok(
2568 max_batch_size
2569 * cfg.num_attention_heads
2570 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2571 )
2572 }
2573 fn non_mapped_max_act_size_elems(
2574 &self,
2575 _config: &str,
2576 _params: &AutoDeviceMapParams,
2577 ) -> Result<usize> {
2578 Ok(0)
2579 }
2580
2581 fn non_mapped_size_in_bytes(
2582 &self,
2583 config: &str,
2584 dtype: DType,
2585 weight_pack_factor: usize,
2586 _matformer_config: Option<&MatformerSliceConfig>,
2587 ) -> Result<usize> {
2588 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2589 let elems = {
2590 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2591 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2593 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2594 } else {
2595 0
2596 };
2597 let norm = cfg.hidden_size;
2598 embed_tokens + lm_head + norm
2599 };
2600 Ok(elems * dtype.size_in_bytes())
2601 }
2602
2603 fn layer_sizes_in_bytes(
2604 &self,
2605 config: &str,
2606 dtype: DType,
2607 weight_pack_factor: usize,
2608 _matformer_config: Option<&MatformerSliceConfig>,
2609 ) -> Result<Vec<usize>> {
2610 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2611 let mut per_layer_elems = Vec::new();
2612
2613 for layer_idx in 0..cfg.num_hidden_layers {
2614 let input_layernorm = cfg.hidden_size;
2615 let post_attention_layernorm = cfg.hidden_size;
2616
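            // With q_lora_rank set, the query path is factored into q_a_proj (+ its norm) and q_b_proj;
            // otherwise a single dense q_proj is counted.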
2617 let q_proj = match cfg.q_lora_rank {
2618 Some(lora_rank) => {
2619 let a = cfg.hidden_size * lora_rank;
2620 let norm = lora_rank;
2621 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2622 a + norm + b
2623 }
2624 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2625 };
2626 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2627 / weight_pack_factor
2628 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2629 let kv_a_layernorm = cfg.kv_lora_rank;
2630 let kv_b_proj = cfg.kv_lora_rank
2631 * cfg.num_attention_heads
2632 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2633 / weight_pack_factor;
2634 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2635 / weight_pack_factor
2636 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2637
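            // MoE layers (index >= first_k_dense_replace, on the moe_layer_freq cadence) count the routed
            // experts, any shared experts, and the gate; all other layers use a dense gated MLP.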
2638 let moe_block = {
2639 let mut sum = 0;
2640 if cfg.n_routed_experts.is_some()
2641 && layer_idx >= cfg.first_k_dense_replace
2642 && layer_idx % cfg.moe_layer_freq == 0
2643 {
2644 let h_size = cfg.hidden_size;
2645 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2646 * cfg.n_routed_experts.unwrap();
2647 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2648 * cfg.n_routed_experts.unwrap();
2649 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2650 * cfg.n_routed_experts.unwrap();
2651 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2652 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2653 / weight_pack_factor;
2654 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2655 / weight_pack_factor;
2656 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2657 / weight_pack_factor;
2658 gate_proj + up_proj + down_proj
2659 } else {
2660 0
2661 };
2662 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2663 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2664 } else {
2665 let h_size = cfg.hidden_size;
2666 let i_size = cfg.intermediate_size;
2667 let gate_proj = h_size * i_size / weight_pack_factor;
2668 let up_proj = h_size * i_size / weight_pack_factor;
2669 let down_proj = i_size * h_size / weight_pack_factor;
2670 sum += gate_proj + up_proj + down_proj;
2671 }
2672 sum
2673 };
2674
2675 per_layer_elems.push(
2676 input_layernorm
2677 + post_attention_layernorm
2678 + q_proj
2679 + kv_a_layernorm
2680 + kv_a_proj_with_mqa
2681 + kv_b_proj
2682 + o_proj
2683 + moe_block,
2684 );
2685 }
2686
2687 Ok(per_layer_elems
2688 .into_iter()
2689 .map(|x| x * dtype.size_in_bytes())
2690 .collect())
2691 }
2692
2693 fn num_layers(&self, config: &str) -> Result<usize> {
2694 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2695 Ok(cfg.num_hidden_layers)
2696 }
2697
2698 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2699 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2700
2701 let cfg = ModelConfigMetadata {
2702 max_seq_len: cfg.max_position_embeddings,
2703 num_layers: cfg.num_hidden_layers,
2704 hidden_size: cfg.hidden_size,
2705 num_kv_heads: cfg.num_attention_heads,
2706 num_attn_heads: cfg.num_attention_heads,
2707 sliding_window: None,
2708 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2709 v_head_dim: cfg.v_head_dim,
2710 };
2711
2712 Ok(Box::new(cfg))
2713 }
2714}
2715
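/// Loader for DeepSeek V3 text models; mirrors the V2 loader but uses DeepSeekV3Config.
/// X-LoRA loading is not implemented.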
2716pub struct DeepSeekV3Loader;
2720
2721impl NormalModelLoader for DeepSeekV3Loader {
2722 fn load(
2723 &self,
2724 config: &str,
2725 vb: ShardedVarBuilder,
2726 normal_loading_metadata: NormalLoadingMetadata,
2727 attention_mechanism: AttentionImplementation,
2728 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2729 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2730 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2731 &cfg,
2732 vb,
2733 self.is_gptx(config)?,
2734 normal_loading_metadata,
2735 attention_mechanism,
2736 )?))
2737 }
2738 fn load_xlora(
2739 &self,
2740 _config: &str,
2741 _vb: ShardedVarBuilder,
2742 _lora_config: &[((String, String), LoraConfig)],
2743 _xlora_config: Option<XLoraConfig>,
2744 _xlora_ordering: Ordering,
2745 _normal_loading_metadata: NormalLoadingMetadata,
2746 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2747 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2748 todo!()
2749 }
2750 fn is_gptx(&self, _: &str) -> Result<bool> {
2751 Ok(true)
2752 }
2753 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2754 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2755 Ok(Box::new(cfg))
2756 }
2757}
2758
2759impl IsqModelLoader for DeepSeekV3Loader {
2760 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2761 let mut data = vec![
2762 Regex::new(r"lm_head\.(weight|bias)$")?,
2763 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2765 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2766 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2767 ];
2768 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2769 if cfg.q_lora_rank.is_some() {
2770 data.extend(vec![
2771 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2772 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2773 ]);
2774 } else {
2775 data.push(Regex::new(
2776 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2777 )?);
2778 }
2779 for layer_idx in 0..cfg.num_hidden_layers {
2780 if cfg.n_routed_experts.is_some()
2781 && layer_idx >= cfg.first_k_dense_replace
2782 && layer_idx % cfg.moe_layer_freq == 0
2783 {
2784 for i in 0..cfg.n_routed_experts.unwrap() {
2785 data.extend(vec![
2786 Regex::new(&format!(
2787 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2788 ))?,
2789 Regex::new(&format!(
2790 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2791 ))?,
2792 Regex::new(&format!(
2793 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2794 ))?,
2795 ]);
2796 }
2797 if cfg.n_shared_experts.is_some() {
2798 data.extend(vec![
2799 Regex::new(&format!(
2800 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2801 ))?,
2802 Regex::new(&format!(
2803 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2804 ))?,
2805 Regex::new(&format!(
2806 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2807 ))?,
2808 ]);
2809 }
2810 } else {
2811 data.extend(vec![
2812 Regex::new(&format!(
2813 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2814 ))?,
2815                Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2816 Regex::new(&format!(
2817 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2818 ))?,
2819 ]);
2820 };
2821 }
2822 Ok(data)
2823 }
2824 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2825 self.isq_layer_regexes(config)
2826 }
2827
2828 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2829 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2830 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2831 for layer_idx in 0..cfg.num_hidden_layers {
2832 if cfg.n_routed_experts.is_some()
2833 && layer_idx >= cfg.first_k_dense_replace
2834 && layer_idx % cfg.moe_layer_freq == 0
2835 {
2836 for i in 0..cfg.n_routed_experts.unwrap() {
2837 data.extend(vec![
2838 Regex::new(&format!(
2839 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2840 ))?,
2841 Regex::new(&format!(
2842 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2843 ))?,
2844 Regex::new(&format!(
2845 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2846 ))?,
2847 ]);
2848 }
2849 if cfg.n_shared_experts.is_some() {
2850 data.extend(vec![
2851 Regex::new(&format!(
2852 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2853 ))?,
2854 Regex::new(&format!(
2855 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2856 ))?,
2857 Regex::new(&format!(
2858 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2859 ))?,
2860 ]);
2861 }
2862 } else {
2863 data.extend(vec![
2864 Regex::new(&format!(
2865 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2866 ))?,
2867                Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2868 Regex::new(&format!(
2869 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2870 ))?,
2871 ]);
2872 };
2873 }
2874 Ok(data)
2875 }
2876 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2877 self.isq_layer_regexes_moqe(config)
2878 }
2879}
2880
2881impl DeviceMappedModelLoader for DeepSeekV3Loader {
2882 fn mapped_max_act_size_elems(
2883 &self,
2884 config: &str,
2885 params: &AutoDeviceMapParams,
2886 ) -> Result<usize> {
2887 let AutoDeviceMapParams::Text {
2888 max_seq_len,
2889 max_batch_size,
2890 } = params
2891 else {
2892 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2893 };
2894
2895 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2896
2897 Ok(
2898 max_batch_size
2899 * cfg.num_attention_heads
2900 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2901 )
2902 }
2903 fn non_mapped_max_act_size_elems(
2904 &self,
2905 _config: &str,
2906 _params: &AutoDeviceMapParams,
2907 ) -> Result<usize> {
2908 Ok(0)
2909 }
2910
2911 fn non_mapped_size_in_bytes(
2912 &self,
2913 config: &str,
2914 dtype: DType,
2915 weight_pack_factor: usize,
2916 _matformer_config: Option<&MatformerSliceConfig>,
2917 ) -> Result<usize> {
2918 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2919 let elems = {
2920 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2921 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2923 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2924 } else {
2925 0
2926 };
2927 let norm = cfg.hidden_size;
2928 embed_tokens + lm_head + norm
2929 };
2930 Ok(elems * dtype.size_in_bytes())
2931 }
2932
2933 fn layer_sizes_in_bytes(
2934 &self,
2935 config: &str,
2936 dtype: DType,
2937 weight_pack_factor: usize,
2938 _matformer_config: Option<&MatformerSliceConfig>,
2939 ) -> Result<Vec<usize>> {
2940 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2941 let mut per_layer_elems = Vec::new();
2942
2943 for layer_idx in 0..cfg.num_hidden_layers {
2944 let input_layernorm = cfg.hidden_size;
2945 let post_attention_layernorm = cfg.hidden_size;
2946
2947 let q_proj = match cfg.q_lora_rank {
2948 Some(lora_rank) => {
2949 let a = cfg.hidden_size * lora_rank;
2950 let norm = lora_rank;
2951 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2952 a + norm + b
2953 }
2954 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2955 };
2956 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2957 / weight_pack_factor
2958 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2959 let kv_a_layernorm = cfg.kv_lora_rank;
2960 let kv_b_proj = cfg.kv_lora_rank
2961 * cfg.num_attention_heads
2962 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2963 / weight_pack_factor;
2964 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2965 / weight_pack_factor
2966 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2967
2968 let moe_block = {
2969 let mut sum = 0;
2970 if cfg.n_routed_experts.is_some()
2971 && layer_idx >= cfg.first_k_dense_replace
2972 && layer_idx % cfg.moe_layer_freq == 0
2973 {
2974 let h_size = cfg.hidden_size;
2975 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2976 * cfg.n_routed_experts.unwrap();
2977 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2978 * cfg.n_routed_experts.unwrap();
2979 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2980 * cfg.n_routed_experts.unwrap();
2981 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2982 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2983 / weight_pack_factor;
2984 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2985 / weight_pack_factor;
2986 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2987 / weight_pack_factor;
2988 gate_proj + up_proj + down_proj
2989 } else {
2990 0
2991 };
2992 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2993 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2994 } else {
2995 let h_size = cfg.hidden_size;
2996 let i_size = cfg.intermediate_size;
2997 let gate_proj = h_size * i_size / weight_pack_factor;
2998 let up_proj = h_size * i_size / weight_pack_factor;
2999 let down_proj = i_size * h_size / weight_pack_factor;
3000 sum += gate_proj + up_proj + down_proj;
3001 }
3002 sum
3003 };
3004
3005 per_layer_elems.push(
3006 input_layernorm
3007 + post_attention_layernorm
3008 + q_proj
3009 + kv_a_layernorm
3010 + kv_a_proj_with_mqa
3011 + kv_b_proj
3012 + o_proj
3013 + moe_block,
3014 );
3015 }
3016
3017 Ok(per_layer_elems
3018 .into_iter()
3019 .map(|x| x * dtype.size_in_bytes())
3020 .collect())
3021 }
3022
3023 fn num_layers(&self, config: &str) -> Result<usize> {
3024 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3025 Ok(cfg.num_hidden_layers)
3026 }
3027
3028 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3029 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3030
3031 let cfg = ModelConfigMetadata {
3032 max_seq_len: cfg.max_position_embeddings,
3033 num_layers: cfg.num_hidden_layers,
3034 hidden_size: cfg.hidden_size,
3035 num_kv_heads: cfg.num_attention_heads,
3036 num_attn_heads: cfg.num_attention_heads,
3037 sliding_window: None,
3038 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3039 v_head_dim: cfg.v_head_dim,
3040 };
3041
3042 Ok(Box::new(cfg))
3043 }
3044}
3045
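/// Loader for dense Qwen3 text models. X-LoRA loading is not implemented.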
3046pub struct Qwen3Loader;
3050
3051impl NormalModelLoader for Qwen3Loader {
3052 fn load(
3053 &self,
3054 config: &str,
3055 vb: ShardedVarBuilder,
3056 normal_loading_metadata: NormalLoadingMetadata,
3057 attention_mechanism: AttentionImplementation,
3058 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3059 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3060
3061 Ok(Box::new(models::qwen3::Model::new(
3062 &cfg,
3063 vb,
3064 self.is_gptx(config)?,
3065 normal_loading_metadata,
3066 attention_mechanism,
3067 )?))
3068 }
3069 fn load_xlora(
3070 &self,
3071 _config: &str,
3072 _vb: ShardedVarBuilder,
3073 _lora_config: &[((String, String), LoraConfig)],
3074 _xlora_config: Option<XLoraConfig>,
3075 _xlora_ordering: Ordering,
3076 _normal_loading_metadata: NormalLoadingMetadata,
3077 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3078 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3079 todo!()
3080 }
3081 fn is_gptx(&self, _: &str) -> Result<bool> {
3082 Ok(true)
3083 }
3084 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3085 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3086
3087 Ok(Box::new(cfg))
3088 }
3089}
3090
3091impl IsqModelLoader for Qwen3Loader {
3092 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3093 Ok(vec![
3094 Regex::new(r"lm_head\.(weight|bias)$")?,
3095 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3097 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3098 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3099 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3100 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3102 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3103 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3104 ])
3105 }
3106 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3107 self.isq_layer_regexes(config)
3108 }
3109}
3110
3111impl DeviceMappedModelLoader for Qwen3Loader {
3112 fn mapped_max_act_size_elems(
3113 &self,
3114 config: &str,
3115 params: &AutoDeviceMapParams,
3116 ) -> Result<usize> {
3117 let AutoDeviceMapParams::Text {
3118 max_seq_len,
3119 max_batch_size,
3120 } = params
3121 else {
3122 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3123 };
3124
3125 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3126
3127 Ok(
3128 max_batch_size
3129 * cfg.num_attention_heads
3130 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3131 )
3132 }
3133 fn non_mapped_max_act_size_elems(
3134 &self,
3135 _config: &str,
3136 _params: &AutoDeviceMapParams,
3137 ) -> Result<usize> {
3138 Ok(0)
3139 }
3140
3141 fn non_mapped_size_in_bytes(
3142 &self,
3143 config: &str,
3144 dtype: DType,
3145 weight_pack_factor: usize,
3146 _matformer_config: Option<&MatformerSliceConfig>,
3147 ) -> Result<usize> {
3148 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3149 let elems = {
3150 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3151 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3153 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3154 } else {
3155 0
3156 };
3157 let norm = cfg.hidden_size;
3158 embed_tokens + lm_head + norm
3159 };
3160 Ok(elems * dtype.size_in_bytes())
3161 }
3162
3163 fn layer_sizes_in_bytes(
3164 &self,
3165 config: &str,
3166 dtype: DType,
3167 weight_pack_factor: usize,
3168 _matformer_config: Option<&MatformerSliceConfig>,
3169 ) -> Result<Vec<usize>> {
3170 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3171 let per_layer_elems = {
3172 let input_layernorm = cfg.hidden_size;
3173 let post_attention_layernorm = cfg.hidden_size;
3174
3175 let size_in = cfg.hidden_size;
3176 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3177 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3178 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3179 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3180 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3181 let o_proj = size_q * size_in / weight_pack_factor;
3182
3183 let h_size = cfg.hidden_size;
3184 let i_size = cfg.intermediate_size;
3185 let gate_proj = h_size * i_size / weight_pack_factor;
3186 let up_proj = h_size * i_size / weight_pack_factor;
3187 let down_proj = i_size * h_size / weight_pack_factor;
3188
3189 let q_norm = cfg.head_dim();
3190 let k_norm = cfg.head_dim();
3191
3192 input_layernorm
3193 + post_attention_layernorm
3194 + q_proj
3195 + k_proj
3196 + v_proj
3197 + o_proj
3198 + gate_proj
3199 + up_proj
3200 + down_proj
3201 + q_norm
3202 + k_norm
3203 };
3204 Ok(vec![
3205 per_layer_elems * dtype.size_in_bytes();
3206 cfg.num_hidden_layers
3207 ])
3208 }
3209
3210 fn num_layers(&self, config: &str) -> Result<usize> {
3211 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3212 Ok(cfg.num_hidden_layers)
3213 }
3214
3215 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3216 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3217
3218 let cfg = ModelConfigMetadata {
3219 max_seq_len: cfg.max_position_embeddings,
3220 num_layers: cfg.num_hidden_layers,
3221 hidden_size: cfg.hidden_size,
3222 num_kv_heads: cfg.num_key_value_heads,
3223 num_attn_heads: cfg.num_attention_heads,
3224 sliding_window: cfg.sliding_window,
3225 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3226 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3227 };
3228
3229 Ok(Box::new(cfg))
3230 }
3231}
3232
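/// Loader for GLM4 text models. X-LoRA loading is not implemented.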
3233pub struct GLM4Loader;
3237
3238impl NormalModelLoader for GLM4Loader {
3239 fn load(
3240 &self,
3241 config: &str,
3242 vb: ShardedVarBuilder,
3243 normal_loading_metadata: NormalLoadingMetadata,
3244 attention_mechanism: AttentionImplementation,
3245 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3246 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3247
3248 Ok(Box::new(models::glm4::Model::new(
3249 &cfg,
3250 vb,
3251 self.is_gptx(config)?,
3252 normal_loading_metadata,
3253 attention_mechanism,
3254 )?))
3255 }
3256 fn load_xlora(
3257 &self,
3258 _config: &str,
3259 _vb: ShardedVarBuilder,
3260 _lora_config: &[((String, String), LoraConfig)],
3261 _xlora_config: Option<XLoraConfig>,
3262 _xlora_ordering: Ordering,
3263 _normal_loading_metadata: NormalLoadingMetadata,
3264 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3265 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3266 todo!()
3267 }
3268 fn is_gptx(&self, _: &str) -> Result<bool> {
3269 Ok(true)
3270 }
3271 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3272 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3273
3274 Ok(Box::new(cfg))
3275 }
3276}
3277
3278impl IsqModelLoader for GLM4Loader {
3279 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3280 Ok(vec![
3281 Regex::new(r"lm_head\.(weight|bias)$")?,
3282 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3284 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3285 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3286 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3287 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3289 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3290 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3291 ])
3292 }
3293 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3294 self.isq_layer_regexes(config)
3295 }
3296}
3297
3298impl DeviceMappedModelLoader for GLM4Loader {
3299 fn mapped_max_act_size_elems(
3300 &self,
3301 config: &str,
3302 params: &AutoDeviceMapParams,
3303 ) -> Result<usize> {
3304 let AutoDeviceMapParams::Text {
3305 max_seq_len,
3306 max_batch_size,
3307 } = params
3308 else {
3309 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3310 };
3311
3312 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3313
3314 Ok(
3315 max_batch_size
3316 * cfg.num_attention_heads
3317 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3318 )
3319 }
3320 fn non_mapped_max_act_size_elems(
3321 &self,
3322 _config: &str,
3323 _params: &AutoDeviceMapParams,
3324 ) -> Result<usize> {
3325 Ok(0)
3326 }
3327
3328 fn non_mapped_size_in_bytes(
3329 &self,
3330 config: &str,
3331 dtype: DType,
3332 weight_pack_factor: usize,
3333 _matformer_config: Option<&MatformerSliceConfig>,
3334 ) -> Result<usize> {
3335 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3336 let elems = {
3337 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3338 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3340 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3341 } else {
3342 0
3343 };
3344 let norm = cfg.hidden_size;
3345 embed_tokens + lm_head + norm
3346 };
3347 Ok(elems * dtype.size_in_bytes())
3348 }
3349
3350 fn layer_sizes_in_bytes(
3351 &self,
3352 config: &str,
3353 dtype: DType,
3354 weight_pack_factor: usize,
3355 _matformer_config: Option<&MatformerSliceConfig>,
3356 ) -> Result<Vec<usize>> {
3357 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3358 let per_layer_elems = {
3359 let input_layernorm = cfg.hidden_size;
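            // The factor of 3 covers GLM4's extra per-layer norms (post self-attention and post MLP)
            // together with the usual post-attention norm.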
3360            let post_attention_layernorm = cfg.hidden_size * 3;
3361
3362            let size_in = cfg.hidden_size;
3363 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3364 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3365 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3366 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3367 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3368 let o_proj = size_q * size_in / weight_pack_factor;
3369
3370 let h_size = cfg.hidden_size;
3371 let i_size = cfg.intermediate_size;
3372 let gate_proj = h_size * i_size / weight_pack_factor;
3373 let up_proj = h_size * i_size / weight_pack_factor;
3374 let down_proj = i_size * h_size / weight_pack_factor;
3375
3376 input_layernorm
3377 + post_attention_layernorm
3378 + q_proj
3379 + k_proj
3380 + v_proj
3381 + o_proj
3382 + gate_proj
3383 + up_proj
3384 + down_proj
3385 };
3386 Ok(vec![
3387 per_layer_elems * dtype.size_in_bytes();
3388 cfg.num_hidden_layers
3389 ])
3390 }
3391
3392 fn num_layers(&self, config: &str) -> Result<usize> {
3393 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3394 Ok(cfg.num_hidden_layers)
3395 }
3396
3397 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3398 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3399
3400 let cfg = ModelConfigMetadata {
3401 max_seq_len: cfg.max_position_embeddings,
3402 num_layers: cfg.num_hidden_layers,
3403 hidden_size: cfg.hidden_size,
3404 num_kv_heads: cfg.num_key_value_heads,
3405 num_attn_heads: cfg.num_attention_heads,
3406 sliding_window: cfg.sliding_window,
3407 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3408 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3409 };
3410
3411 Ok(Box::new(cfg))
3412 }
3413}
3414
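/// Loader for Qwen3 MoE text models; per-layer size estimates distinguish sparse expert
/// layers from dense MLP-only layers. X-LoRA loading is not implemented.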
3415pub struct Qwen3MoELoader;
3419
3420impl NormalModelLoader for Qwen3MoELoader {
3421 fn load(
3422 &self,
3423 config: &str,
3424 vb: ShardedVarBuilder,
3425 normal_loading_metadata: NormalLoadingMetadata,
3426 attention_mechanism: AttentionImplementation,
3427 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3428 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3429
3430 Ok(Box::new(models::qwen3_moe::Model::new(
3431 &cfg,
3432 vb,
3433 self.is_gptx(config)?,
3434 normal_loading_metadata,
3435 attention_mechanism,
3436 )?))
3437 }
3438 fn load_xlora(
3439 &self,
3440 _config: &str,
3441 _vb: ShardedVarBuilder,
3442 _lora_config: &[((String, String), LoraConfig)],
3443 _xlora_config: Option<XLoraConfig>,
3444 _xlora_ordering: Ordering,
3445 _normal_loading_metadata: NormalLoadingMetadata,
3446 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3447 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3448 todo!()
3449 }
3450 fn is_gptx(&self, _: &str) -> Result<bool> {
3451 Ok(true)
3452 }
3453 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3454 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3455
3456 Ok(Box::new(cfg))
3457 }
3458}
3459
3460impl IsqModelLoader for Qwen3MoELoader {
3461 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3462 Ok(vec![
3463 Regex::new(r"lm_head\.(weight|bias)$")?,
3464 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3466 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3467 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3468 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3469 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3471 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3472 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3473 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3475 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3476 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3477 ])
3478 }
3479 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3480 self.isq_layer_regexes(config)
3481 }
3482 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3483 self.isq_layer_regexes_moqe(config)
3484 }
3485}
3486
3487impl DeviceMappedModelLoader for Qwen3MoELoader {
3488 fn mapped_max_act_size_elems(
3489 &self,
3490 config: &str,
3491 params: &AutoDeviceMapParams,
3492 ) -> Result<usize> {
3493 let AutoDeviceMapParams::Text {
3494 max_seq_len,
3495 max_batch_size,
3496 } = params
3497 else {
3498 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3499 };
3500
3501 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3502
3503 Ok(
3504 max_batch_size
3505 * cfg.num_attention_heads
3506 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3507 )
3508 }
3509 fn non_mapped_max_act_size_elems(
3510 &self,
3511 _config: &str,
3512 _params: &AutoDeviceMapParams,
3513 ) -> Result<usize> {
3514 Ok(0)
3515 }
3516
3517 fn non_mapped_size_in_bytes(
3518 &self,
3519 config: &str,
3520 dtype: DType,
3521 weight_pack_factor: usize,
3522 _matformer_config: Option<&MatformerSliceConfig>,
3523 ) -> Result<usize> {
3524 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3525 let elems = {
3526            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3527 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3529 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3530 } else {
3531 0
3532 };
3533 let norm = cfg.hidden_size;
3534 embed_tokens + lm_head + norm
3535 };
3536 Ok(elems * dtype.size_in_bytes())
3537 }
3538
3539 fn layer_sizes_in_bytes(
3540 &self,
3541 config: &str,
3542 dtype: DType,
3543 weight_pack_factor: usize,
3544 _matformer_config: Option<&MatformerSliceConfig>,
3545 ) -> Result<Vec<usize>> {
3546 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3547
3548 let mut layer_sizes_in_bytes = Vec::new();
3549 for layer_idx in 0..cfg.num_hidden_layers {
3550 let input_layernorm = cfg.hidden_size;
3551 let post_attention_layernorm = cfg.hidden_size;
3552
3553 let size_in = cfg.hidden_size;
3554 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3555 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3556 let q_proj = size_in * size_q / weight_pack_factor;
3557 let k_proj = size_in * size_kv / weight_pack_factor;
3558 let v_proj = size_in * size_kv / weight_pack_factor;
3559 let o_proj = size_q * size_in / weight_pack_factor;
3560
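            // Sparse MoE MLP unless this layer is listed in mlp_only_layers or falls off the
            // decoder_sparse_step cadence; dense layers use the regular gated MLP sizes.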
3561 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3562 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3563 {
3564 let gate_size = cfg.hidden_size * cfg.num_experts;
3565 let expert_size = {
3566 let h_size = cfg.hidden_size;
3567 let i_size = cfg.moe_intermediate_size;
3568 let gate_proj = h_size * i_size / weight_pack_factor;
3569 let up_proj = h_size * i_size / weight_pack_factor;
3570 let down_proj = i_size * h_size / weight_pack_factor;
3571 gate_proj + up_proj + down_proj
3572 };
3573 expert_size * cfg.num_experts + gate_size
3574 } else {
3575 let h_size = cfg.hidden_size;
3576 let i_size = cfg.intermediate_size;
3577 let gate_proj = h_size * i_size / weight_pack_factor;
3578 let up_proj = h_size * i_size / weight_pack_factor;
3579 let down_proj = i_size * h_size / weight_pack_factor;
3580 gate_proj + up_proj + down_proj
3581 };
3582
3583 let q_norm = cfg.head_dim();
3584 let k_norm = cfg.head_dim();
3585
3586 let size_elems = input_layernorm
3587 + post_attention_layernorm
3588 + q_proj
3589 + k_proj
3590 + v_proj
3591 + o_proj
3592 + mlp_size
3593 + q_norm
3594 + k_norm;
3595
3596 let size_in_bytes = size_elems * dtype.size_in_bytes();
3597 layer_sizes_in_bytes.push(size_in_bytes);
3598 }
3599
3600 Ok(layer_sizes_in_bytes)
3601 }
3602
3603 fn num_layers(&self, config: &str) -> Result<usize> {
3604 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3605 Ok(cfg.num_hidden_layers)
3606 }
3607
3608 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3609 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3610
3611 let cfg = ModelConfigMetadata {
3612 max_seq_len: cfg.max_position_embeddings,
3613 num_layers: cfg.num_hidden_layers,
3614 hidden_size: cfg.hidden_size,
3615 num_kv_heads: cfg.num_key_value_heads,
3616 num_attn_heads: cfg.num_attention_heads,
3617 sliding_window: cfg.sliding_window,
3618 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3619 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3620 };
3621
3622 Ok(Box::new(cfg))
3623 }
3624}
3625
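/// Loader for SmolLM3 text models. X-LoRA loading is not implemented.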
3626pub struct SmolLm3Loader;
3632
3633impl NormalModelLoader for SmolLm3Loader {
3634 fn load(
3635 &self,
3636 config: &str,
3637 vb: ShardedVarBuilder,
3638 normal_loading_metadata: NormalLoadingMetadata,
3639 attention_mechanism: AttentionImplementation,
3640 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3641 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3642
3643 Ok(Box::new(models::smollm3::SmolLm3::new(
3644 &cfg,
3645 vb,
3646 self.is_gptx(config)?,
3647 normal_loading_metadata,
3648 attention_mechanism,
3649 )?))
3650 }
3651 fn load_xlora(
3652 &self,
3653 _config: &str,
3654 _vb: ShardedVarBuilder,
3655 _lora_config: &[((String, String), LoraConfig)],
3656 _xlora_config: Option<XLoraConfig>,
3657 _xlora_ordering: Ordering,
3658 _normal_loading_metadata: NormalLoadingMetadata,
3659 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3660 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3661 todo!()
3662 }
3663 fn is_gptx(&self, _: &str) -> Result<bool> {
3664 Ok(true)
3665 }
3666 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3667 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3668 Ok(Box::new(cfg))
3669 }
3670}
3671
3672impl IsqModelLoader for SmolLm3Loader {
3673 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3674 Ok(vec![
3675 Regex::new(r"lm_head\.(weight|bias)$")?,
3676 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3678 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3679 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3680 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3681 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3683 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3684 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3685 ])
3686 }
3687 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3688 self.isq_layer_regexes(config)
3689 }
3690}
3691
3692impl DeviceMappedModelLoader for SmolLm3Loader {
3693 fn mapped_max_act_size_elems(
3694 &self,
3695 config: &str,
3696 params: &AutoDeviceMapParams,
3697 ) -> Result<usize> {
3698 let AutoDeviceMapParams::Text {
3699 max_seq_len,
3700 max_batch_size,
3701 } = params
3702 else {
3703 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3704 };
3705
3706 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3707
3708 Ok(
3709 max_batch_size
3710 * cfg.num_attention_heads
3711 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3712 )
3713 }
3714 fn non_mapped_max_act_size_elems(
3715 &self,
3716 _config: &str,
3717 _params: &AutoDeviceMapParams,
3718 ) -> Result<usize> {
3719 Ok(0)
3720 }
3721
3722 fn non_mapped_size_in_bytes(
3723 &self,
3724 config: &str,
3725 dtype: DType,
3726 weight_pack_factor: usize,
3727 _matformer_config: Option<&MatformerSliceConfig>,
3728 ) -> Result<usize> {
3729 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3730
3731 let elems = {
3732 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3733 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3735 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3736 } else {
3737 0
3738 };
3739 let norm = cfg.hidden_size;
3740 embed_tokens + lm_head + norm
3741 };
3742 Ok(elems * dtype.size_in_bytes())
3743 }
3744
3745 fn layer_sizes_in_bytes(
3746 &self,
3747 config: &str,
3748 dtype: DType,
3749 weight_pack_factor: usize,
3750 _matformer_config: Option<&MatformerSliceConfig>,
3751 ) -> Result<Vec<usize>> {
3752 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3753
3754 let per_layer_elems = {
3755 let input_layernorm = cfg.hidden_size;
3756 let post_attention_layernorm = cfg.hidden_size;
3757
3758 let size_in = cfg.hidden_size;
3759 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3760 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3761 let q_proj = size_in * size_q / weight_pack_factor;
3762 let k_proj = size_in * size_kv / weight_pack_factor;
3763 let v_proj = size_in * size_kv / weight_pack_factor;
3764 let o_proj = size_q * size_in / weight_pack_factor;
3765
3766 let h_size = cfg.hidden_size;
3767 let i_size = cfg.intermediate_size;
3768 let gate_proj = h_size * i_size / weight_pack_factor;
3769 let up_proj = h_size * i_size / weight_pack_factor;
3770 let down_proj = i_size * h_size / weight_pack_factor;
3771
3772 input_layernorm
3773 + post_attention_layernorm
3774 + q_proj
3775 + k_proj
3776 + v_proj
3777 + o_proj
3778 + gate_proj
3779 + up_proj
3780 + down_proj
3781 };
3782 Ok(vec![
3783 per_layer_elems * dtype.size_in_bytes();
3784 cfg.num_hidden_layers
3785 ])
3786 }
3787
3788 fn num_layers(&self, config: &str) -> Result<usize> {
3789 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3790
3791 Ok(cfg.num_hidden_layers)
3792 }
3793 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3794 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3795
3796 let cfg = ModelConfigMetadata {
3797 max_seq_len: cfg.max_position_embeddings,
3798 num_layers: cfg.num_hidden_layers,
3799 hidden_size: cfg.hidden_size,
3800 num_kv_heads: cfg.num_key_value_heads,
3801 num_attn_heads: cfg.num_attention_heads,
3802 sliding_window: None,
3803 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3804 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3805 };
3806
3807 Ok(Box::new(cfg))
3808 }
3809}