1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 str::FromStr,
5 sync::Arc,
6};
7
8use crate::matformer::MatformerSliceConfig;
9
10use crate::{
11 amoe::AnyMoeBaseModelMixin,
12 device_map::DeviceMapper,
13 lora::{LoraConfig, Ordering},
14 paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
15 pipeline::{
16 isq::IsqModelLoader,
17 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
18 EitherCache, IsqModel,
19 },
20 utils::varbuilder_utils::DeviceForLoadTensor,
21 xlora_models::NonGranularState,
22};
23use anyhow::Result;
24use candle_core::{DType, Device, Tensor};
25use mistralrs_quant::log::once_log_info;
26
27use indicatif::MultiProgress;
28use mistralrs_quant::ShardedVarBuilder;
29#[cfg(feature = "pyo3_macros")]
30use pyo3::pyclass;
31
32use regex::Regex;
33use serde::Deserialize;
34
35use crate::{
36 models,
37 xlora_models::{self, XLoraConfig},
38};
39
40use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
41
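/// Interface implemented by the "normal" (plain text decoder) models: a standard forward
/// pass, an X-LoRA forward pass, and accessors for the KV cache, device, maximum sequence
/// length, and config metadata.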
42pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
43 #[allow(clippy::too_many_arguments)]
44 fn forward(
45 &self,
46 input_ids: &Tensor,
47 seqlen_offsets: &[usize],
48 context_lens: Vec<(usize, usize)>,
49 position_ids: Vec<usize>,
50 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
51 flash_params: &FlashParams,
52 ) -> candle_core::Result<Tensor>;
53 #[allow(clippy::too_many_arguments)]
54 fn xlora_forward(
55 &self,
56 input_ids: &Tensor,
57 input_ids_full: &Tensor,
58 seqlen_offsets: &[usize],
59 seqlen_offsets_full: &[usize],
60 no_kv_cache: bool,
61 non_granular_state: &Option<NonGranularState>,
62 context_lens: Vec<(usize, usize)>,
63 position_ids: Vec<usize>,
64 flash_params: &FlashParams,
65 flash_params_full: &FlashParams,
66 ) -> candle_core::Result<Tensor>;
67 fn is_xlora(&self) -> bool;
68 fn device(&self) -> &Device;
69 fn cache(&self) -> &EitherCache;
70 fn cache_mut(&mut self) -> &mut EitherCache;
71 fn max_seq_len(&self) -> usize;
72 fn config(&self) -> &ModelConfigMetadata;
73}
74
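/// Everything a [`NormalModelLoader`] needs, besides the config and the weights themselves,
/// to instantiate a model.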
75pub struct NormalLoadingMetadata {
    /// Device mapper used to place the model's layers across the available devices.
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Whether in-situ quantization (ISQ) will be applied to this model.
    pub loading_isq: bool,
    /// The primary device the model is being loaded onto.
    pub real_device: Device,
    /// Shared progress reporting handle used while loading tensors.
    pub multi_progress: Arc<MultiProgress>,
    /// Optional Matformer slicing configuration to apply at load time.
    pub matformer_slicing_config: Option<MatformerSliceConfig>,
87}
88
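/// Builds a concrete [`NormalModel`] (optionally with X-LoRA adapters) from a serialized JSON
/// model configuration and a [`ShardedVarBuilder`] over the weights.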
89pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
90 fn load(
91 &self,
92 config: &str,
93 vb: ShardedVarBuilder,
94 normal_loading_metadata: NormalLoadingMetadata,
95 attention_mechanism: AttentionImplementation,
96 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
97 #[allow(clippy::too_many_arguments)]
98 fn load_xlora(
99 &self,
100 config: &str,
101 vb: ShardedVarBuilder,
102 lora_config: &[((String, String), LoraConfig)],
103 xlora_config: Option<XLoraConfig>,
104 xlora_ordering: Ordering,
105 normal_loading_metadata: NormalLoadingMetadata,
106 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
107 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
108 fn is_gptx(&self, config: &str) -> Result<bool>;
109 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
110 Ok(true)
111 }
112 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
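    /// Decide which device each tensor should be loaded onto. By default, tensors whose names
    /// contain `.layers.<n>.` are assigned to layer `n` (capped at the layer count) and
    /// everything else falls back to the base device; when ISQ is being applied, all tensors
    /// use the base device.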
113 fn get_device_for_tensor(
114 &self,
115 config: &str,
116 _mapper: &dyn DeviceMapper,
117 loading_isq: bool,
118 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
119 if loading_isq {
120 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
121 } else {
122 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
123 let num_layers = self.model_config(config)?.num_layers();
124 let closure = move |name: String| {
125 if let Some(captures) = re.captures(&name) {
126 captures
127 .get(1)
128 .and_then(|m| m.as_str().parse::<usize>().ok())
129 .map(|l| l.min(num_layers))
130 .map(DeviceForLoadTensor::Idx)
131 .unwrap_or(DeviceForLoadTensor::Base)
132 } else {
133 DeviceForLoadTensor::Base
134 }
135 };
136
137 Ok(Arc::new(closure))
138 }
139 }
140}
141
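/// The set of supported text model architectures. The serde names double as the strings
/// accepted by `FromStr`, e.g. `"mistral".parse::<NormalLoaderType>()` yields
/// `NormalLoaderType::Mistral`.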
142#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
143#[derive(Clone, Debug, Deserialize, PartialEq)]
144pub enum NormalLoaderType {
146 #[serde(rename = "mistral")]
147 Mistral,
148 #[serde(rename = "gemma")]
149 Gemma,
150 #[serde(rename = "mixtral")]
151 Mixtral,
152 #[serde(rename = "llama")]
153 Llama,
154 #[serde(rename = "phi2")]
155 Phi2,
156 #[serde(rename = "phi3")]
157 Phi3,
158 #[serde(rename = "qwen2")]
159 Qwen2,
160 #[serde(rename = "gemma2")]
161 Gemma2,
162 #[serde(rename = "starcoder2")]
163 Starcoder2,
164 #[serde(rename = "phi3.5moe")]
165 Phi3_5MoE,
166 #[serde(rename = "deepseekv2")]
167 DeepSeekV2,
168 #[serde(rename = "deepseekv3")]
169 DeepSeekV3,
170 #[serde(rename = "qwen3")]
171 Qwen3,
172 #[serde(rename = "glm4")]
173 GLM4,
174 #[serde(rename = "qwen3moe")]
175 Qwen3Moe,
176 #[serde(rename = "smollm3")]
177 SmolLm3,
178}
179
180impl NormalLoaderType {
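    /// Map a Hugging Face `architectures` entry from `config.json` (e.g. `MistralForCausalLM`)
    /// to the corresponding loader type.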
182 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
183 match name {
184 "MistralForCausalLM" => Ok(Self::Mistral),
185 "MixtralForCausalLM" => Ok(Self::Mixtral),
186 "GemmaForCausalLM" => Ok(Self::Gemma),
187 "Gemma2ForCausalLM" => Ok(Self::Gemma2),
188 "PhiForCausalLM" => Ok(Self::Phi2),
189 "Phi3ForCausalLM" => Ok(Self::Phi3),
190 "LlamaForCausalLM" => Ok(Self::Llama),
191 "Qwen2ForCausalLM" => Ok(Self::Qwen2),
192 "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
193 "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
194 "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
195 "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
196 "Qwen3ForCausalLM" => Ok(Self::Qwen3),
197 "Glm4ForCausalLM" => Ok(Self::GLM4),
198 "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
199 "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `*ForCausalLM` model class `{other}`. Please raise an issue."
            ),
203 }
204 }
205}
206
207impl FromStr for NormalLoaderType {
208 type Err = String;
209 fn from_str(s: &str) -> Result<Self, Self::Err> {
210 match s {
211 "mistral" => Ok(Self::Mistral),
212 "gemma" => Ok(Self::Gemma),
213 "mixtral" => Ok(Self::Mixtral),
214 "llama" => Ok(Self::Llama),
215 "phi2" => Ok(Self::Phi2),
216 "phi3" => Ok(Self::Phi3),
217 "qwen2" => Ok(Self::Qwen2),
218 "gemma2" => Ok(Self::Gemma2),
219 "starcoder2" => Ok(Self::Starcoder2),
220 "phi3.5moe" => Ok(Self::Phi3_5MoE),
221 "deepseekv2" => Ok(Self::DeepSeekV2),
222 "deepseekv3" => Ok(Self::DeepSeekV3),
223 "qwen3" => Ok(Self::Qwen3),
224 "glm4" => Ok(Self::GLM4),
225 "qwen3moe" => Ok(Self::Qwen3Moe),
226 "smollm3" => Ok(Self::SmolLm3),
227 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`.")),
228 }
229 }
230}
231
232impl Display for NormalLoaderType {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 match self {
235 Self::Gemma => write!(f, "gemma"),
236 Self::Gemma2 => write!(f, "gemma2"),
237 Self::Llama => write!(f, "llama"),
238 Self::Mistral => write!(f, "mistral"),
239 Self::Mixtral => write!(f, "mixtral"),
240 Self::Phi2 => write!(f, "phi2"),
241 Self::Phi3 => write!(f, "phi3"),
242 Self::Phi3_5MoE => write!(f, "phi3.5moe"),
243 Self::Qwen2 => write!(f, "qwen2"),
244 Self::Starcoder2 => write!(f, "starcoder2"),
245 Self::DeepSeekV2 => write!(f, "deepseekv2"),
246 Self::DeepSeekV3 => write!(f, "deepseekv3"),
247 Self::Qwen3 => write!(f, "qwen3"),
248 Self::GLM4 => write!(f, "glm4"),
249 Self::Qwen3Moe => write!(f, "qwen3moe"),
250 Self::SmolLm3 => write!(f, "smollm3"),
251 }
252 }
253}
254
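// Helper for the size estimates below: counts `$size` bias elements only when the
// corresponding bias flag in the config is set.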
255macro_rules! bias_if {
256 ($cond:expr, $size:expr) => {
257 if $cond {
258 $size
259 } else {
260 0
261 }
262 };
263}
264
265pub struct AutoNormalLoader;
267
268#[derive(Deserialize)]
269struct AutoNormalLoaderConfig {
270 architectures: Vec<String>,
271}
272
273impl AutoNormalLoader {
274 fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
275 let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
276 if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one entry in the `architectures` config field.")
278 }
279
280 let name = &auto_cfg.architectures[0];
281
282 let tp = NormalLoaderType::from_causal_lm_name(name)?;
283
284 once_log_info(format!("Automatic loader type determined to be `{tp}`"));
285
286 match tp {
287 NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
288 NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
289 NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
290 NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
291 NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
292 NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
293 NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
294 NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
295 NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
296 NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
297 NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
298 NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
299 NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
300 NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
301 NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
302 NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
303 }
304 }
305}
306
307impl NormalModelLoader for AutoNormalLoader {
308 fn load(
309 &self,
310 config: &str,
311 vb: ShardedVarBuilder,
312 normal_loading_metadata: NormalLoadingMetadata,
313 attention_mechanism: AttentionImplementation,
314 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
315 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
316 }
317 fn load_xlora(
318 &self,
319 config: &str,
320 vb: ShardedVarBuilder,
321 lora_config: &[((String, String), LoraConfig)],
322 xlora_config: Option<XLoraConfig>,
323 xlora_ordering: Ordering,
324 normal_loading_metadata: NormalLoadingMetadata,
325 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
326 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
327 Self::get_loader(config)?.load_xlora(
328 config,
329 vb,
330 lora_config,
331 xlora_config,
332 xlora_ordering,
333 normal_loading_metadata,
334 preload_adapters,
335 )
336 }
337 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
338 Self::get_loader(config)?.get_config_repr(config)
339 }
340 fn supports_paged_attention(&self, config: &str) -> Result<bool> {
341 Self::get_loader(config)?.supports_paged_attention(config)
342 }
343 fn is_gptx(&self, config: &str) -> Result<bool> {
344 Self::get_loader(config)?.is_gptx(config)
345 }
346}
347
348impl IsqModelLoader for AutoNormalLoader {
349 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
350 Self::get_loader(config)?.immediate_isq_predicates(config)
351 }
352 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
353 Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
354 }
355 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
356 Self::get_loader(config)?.isq_layer_regexes(config)
357 }
358 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
359 Self::get_loader(config)?.isq_layer_regexes_moqe(config)
360 }
361}
362
363impl DeviceMappedModelLoader for AutoNormalLoader {
364 fn non_mapped_size_in_bytes(
365 &self,
366 config: &str,
367 dtype: DType,
368 weight_pack_factor: usize,
369 _matformer_config: Option<&MatformerSliceConfig>,
370 ) -> Result<usize> {
371 Self::get_loader(config)?.non_mapped_size_in_bytes(
372 config,
373 dtype,
374 weight_pack_factor,
375 _matformer_config,
376 )
377 }
378 fn num_layers(&self, config: &str) -> Result<usize> {
379 Self::get_loader(config)?.num_layers(config)
380 }
381 fn layer_sizes_in_bytes(
382 &self,
383 config: &str,
384 dtype: DType,
385 weight_pack_factor: usize,
386 _matformer_config: Option<&MatformerSliceConfig>,
387 ) -> Result<Vec<usize>> {
388 Self::get_loader(config)?.layer_sizes_in_bytes(
389 config,
390 dtype,
391 weight_pack_factor,
392 _matformer_config,
393 )
394 }
395 fn mapped_max_act_size_elems(
396 &self,
397 config: &str,
398 params: &super::AutoDeviceMapParams,
399 prompt_chunksize: usize,
400 ) -> Result<usize> {
401 Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
402 }
403 fn non_mapped_max_act_size_elems(
404 &self,
405 _config: &str,
406 _params: &AutoDeviceMapParams,
407 ) -> Result<usize> {
408 Ok(0)
409 }
410 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
411 Self::get_loader(config)?.model_config(config)
412 }
413}
414
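/// [`NormalModelLoader`] for Mistral models (`MistralForCausalLM`).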
415pub struct MistralLoader;
418
419impl NormalModelLoader for MistralLoader {
420 fn load(
421 &self,
422 config: &str,
423 vb: ShardedVarBuilder,
424 normal_loading_metadata: NormalLoadingMetadata,
425 attention_mechanism: AttentionImplementation,
426 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
427 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
428 Ok(Box::new(models::mistral::Model::new(
429 &cfg,
430 vb,
431 self.is_gptx(config)?,
432 normal_loading_metadata,
433 attention_mechanism,
434 )?))
435 }
436 fn load_xlora(
437 &self,
438 config: &str,
439 vb: ShardedVarBuilder,
440 lora_config: &[((String, String), LoraConfig)],
441 xlora_config: Option<XLoraConfig>,
442 xlora_ordering: Ordering,
443 normal_loading_metadata: NormalLoadingMetadata,
444 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
445 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
446 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
447 Ok(Box::new(xlora_models::XLoraMistral::new(
448 &cfg,
449 vb,
450 lora_config,
451 xlora_config,
452 xlora_ordering,
453 self.is_gptx(config)?,
454 normal_loading_metadata,
455 preload_adapters,
456 )?))
457 }
458 fn is_gptx(&self, _: &str) -> Result<bool> {
459 Ok(true)
460 }
461 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
462 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
463 Ok(Box::new(cfg))
464 }
465}
466
467impl IsqModelLoader for MistralLoader {
468 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
469 Ok(vec![
470 Regex::new(r"lm_head\.(weight|bias)$")?,
471 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
473 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
474 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
475 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
476 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
478 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
479 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
480 ])
481 }
482 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
483 self.isq_layer_regexes(config)
484 }
485}
486
487impl DeviceMappedModelLoader for MistralLoader {
488 fn mapped_max_act_size_elems(
489 &self,
490 config: &str,
491 params: &AutoDeviceMapParams,
492 prompt_chunksize: usize,
493 ) -> Result<usize> {
494 let AutoDeviceMapParams::Text {
495 max_seq_len: _,
496 max_batch_size,
497 } = params
498 else {
499 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
500 };
501
502 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
503
504 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
505 }
506 fn non_mapped_max_act_size_elems(
507 &self,
508 _config: &str,
509 _params: &AutoDeviceMapParams,
510 ) -> Result<usize> {
511 Ok(0)
512 }
513
514 fn non_mapped_size_in_bytes(
515 &self,
516 config: &str,
517 dtype: DType,
518 weight_pack_factor: usize,
519 _matformer_config: Option<&MatformerSliceConfig>,
520 ) -> Result<usize> {
521 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
522
523 let elems = {
524 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // Count the LM head unless it is tied to the embedding matrix and no weight
            // packing is in effect.
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
527 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
528 } else {
529 0
530 };
531 let norm = cfg.hidden_size;
532 embed_tokens + lm_head + norm
533 };
534 Ok(elems * dtype.size_in_bytes())
535 }
536
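    // Per-layer weight memory estimate: two normalization vectors (hidden_size elements each),
    // the four attention projections, and the three MLP projections. Packed/quantized formats
    // are approximated by dividing projection element counts by `weight_pack_factor`; the
    // total element count is then converted to bytes using the dtype size.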
537 fn layer_sizes_in_bytes(
538 &self,
539 config: &str,
540 dtype: DType,
541 weight_pack_factor: usize,
542 _matformer_config: Option<&MatformerSliceConfig>,
543 ) -> Result<Vec<usize>> {
544 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
545
546 let per_layer_elems = {
547 let input_layernorm = cfg.hidden_size;
548 let post_attention_layernorm = cfg.hidden_size;
549
550 let size_in = cfg.hidden_size;
551 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
552 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
553 let q_proj = size_in * size_q / weight_pack_factor;
554 let k_proj = size_in * size_kv / weight_pack_factor;
555 let v_proj = size_in * size_kv / weight_pack_factor;
556 let o_proj = size_q * size_in / weight_pack_factor;
557
558 let h_size = cfg.hidden_size;
559 let i_size = cfg.intermediate_size;
560 let gate_proj = h_size * i_size / weight_pack_factor;
561 let up_proj = h_size * i_size / weight_pack_factor;
562 let down_proj = i_size * h_size / weight_pack_factor;
563
564 input_layernorm
565 + post_attention_layernorm
566 + q_proj
567 + k_proj
568 + v_proj
569 + o_proj
570 + gate_proj
571 + up_proj
572 + down_proj
573 };
574 Ok(vec![
575 per_layer_elems * dtype.size_in_bytes();
576 cfg.num_hidden_layers
577 ])
578 }
579
580 fn num_layers(&self, config: &str) -> Result<usize> {
581 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
582 Ok(cfg.num_hidden_layers)
583 }
584
585 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
586 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
587
588 let cfg = ModelConfigMetadata {
589 max_seq_len: cfg.max_position_embeddings,
590 num_layers: cfg.num_hidden_layers,
591 hidden_size: cfg.hidden_size,
592 num_kv_heads: cfg.num_key_value_heads,
593 num_attn_heads: cfg.num_attention_heads,
594 sliding_window: cfg.sliding_window,
595 k_head_dim: cfg.head_dim(),
596 v_head_dim: cfg.head_dim(),
597 };
598
599 Ok(Box::new(cfg))
600 }
601}
602
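/// [`NormalModelLoader`] for Gemma models (`GemmaForCausalLM`).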
603pub struct GemmaLoader;
609
610impl NormalModelLoader for GemmaLoader {
611 fn load(
612 &self,
613 config: &str,
614 vb: ShardedVarBuilder,
615 normal_loading_metadata: NormalLoadingMetadata,
616 attention_mechanism: AttentionImplementation,
617 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
618 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
619
620 Ok(Box::new(models::gemma::Model::new(
621 &cfg,
622 vb,
623 self.is_gptx(config)?,
624 normal_loading_metadata,
625 attention_mechanism,
626 )?))
627 }
628 fn load_xlora(
629 &self,
630 config: &str,
631 vb: ShardedVarBuilder,
632 lora_config: &[((String, String), LoraConfig)],
633 xlora_config: Option<XLoraConfig>,
634 xlora_ordering: Ordering,
635 normal_loading_metadata: NormalLoadingMetadata,
636 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
637 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
638 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
639
640 Ok(Box::new(xlora_models::XLoraGemma::new(
641 &cfg,
642 vb,
643 lora_config,
644 xlora_config,
645 xlora_ordering,
646 self.is_gptx(config)?,
647 normal_loading_metadata,
648 preload_adapters,
649 )?))
650 }
651 fn is_gptx(&self, _: &str) -> Result<bool> {
652 Ok(true)
653 }
654 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
655 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
656 Ok(Box::new(cfg))
657 }
658}
659
660impl IsqModelLoader for GemmaLoader {
661 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
662 Ok(vec![
663 Regex::new(r"lm_head\.(weight|bias)$")?,
664 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
666 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
667 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
668 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
669 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
671 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
672 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
673 ])
674 }
675 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
676 self.isq_layer_regexes(config)
677 }
678}
679
680impl DeviceMappedModelLoader for GemmaLoader {
681 fn mapped_max_act_size_elems(
682 &self,
683 config: &str,
684 params: &AutoDeviceMapParams,
685 prompt_chunksize: usize,
686 ) -> Result<usize> {
687 let AutoDeviceMapParams::Text {
688 max_seq_len: _,
689 max_batch_size,
690 } = params
691 else {
692 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
693 };
694
695 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
696
697 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
698 }
699 fn non_mapped_max_act_size_elems(
700 &self,
701 _config: &str,
702 _params: &AutoDeviceMapParams,
703 ) -> Result<usize> {
704 Ok(0)
705 }
706
707 fn non_mapped_size_in_bytes(
708 &self,
709 config: &str,
710 dtype: DType,
711 weight_pack_factor: usize,
712 _matformer_config: Option<&MatformerSliceConfig>,
713 ) -> Result<usize> {
714 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
715
716 let elems = {
717 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
718 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
720 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
721 } else {
722 0
723 };
724 let norm = cfg.hidden_size;
725 embed_tokens + lm_head + norm
726 };
727 Ok(elems * dtype.size_in_bytes())
728 }
729
730 fn layer_sizes_in_bytes(
731 &self,
732 config: &str,
733 dtype: DType,
734 weight_pack_factor: usize,
735 _matformer_config: Option<&MatformerSliceConfig>,
736 ) -> Result<Vec<usize>> {
737 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
738
739 let per_layer_elems = {
740 let input_layernorm = cfg.hidden_size;
741 let post_attention_layernorm = cfg.hidden_size;
742
743 let size_in = cfg.hidden_size;
744 let size_q = cfg.head_dim * cfg.num_attention_heads;
745 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
746 let q_proj =
747 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
748 let k_proj =
749 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
750 let v_proj =
751 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
752 let o_proj =
753 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
754
755 let h_size = cfg.hidden_size;
756 let i_size = cfg.intermediate_size;
757 let gate_proj = h_size * i_size / weight_pack_factor;
758 let up_proj = h_size * i_size / weight_pack_factor;
759 let down_proj = i_size * h_size / weight_pack_factor;
760
761 input_layernorm
762 + post_attention_layernorm
763 + q_proj
764 + k_proj
765 + v_proj
766 + o_proj
767 + gate_proj
768 + up_proj
769 + down_proj
770 };
771 Ok(vec![
772 per_layer_elems * dtype.size_in_bytes();
773 cfg.num_hidden_layers
774 ])
775 }
776
777 fn num_layers(&self, config: &str) -> Result<usize> {
778 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
779 Ok(cfg.num_hidden_layers)
780 }
781
782 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
783 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
784
785 let cfg = ModelConfigMetadata {
786 max_seq_len: cfg.max_position_embeddings,
787 num_layers: cfg.num_hidden_layers,
788 hidden_size: cfg.hidden_size,
789 num_kv_heads: cfg.num_key_value_heads,
790 num_attn_heads: cfg.num_attention_heads,
791 sliding_window: None,
792 k_head_dim: cfg.head_dim,
793 v_head_dim: cfg.head_dim,
794 };
795
796 Ok(Box::new(cfg))
797 }
798}
799
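/// [`NormalModelLoader`] for Llama models (`LlamaForCausalLM`).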
800pub struct LlamaLoader;
806
807impl NormalModelLoader for LlamaLoader {
808 fn load(
809 &self,
810 config: &str,
811 vb: ShardedVarBuilder,
812 normal_loading_metadata: NormalLoadingMetadata,
813 attention_mechanism: AttentionImplementation,
814 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
815 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
816
817 Ok(Box::new(models::llama::Llama::new(
818 &cfg,
819 vb,
820 self.is_gptx(config)?,
821 normal_loading_metadata,
822 attention_mechanism,
823 )?))
824 }
825 fn load_xlora(
826 &self,
827 config: &str,
828 vb: ShardedVarBuilder,
829 lora_config: &[((String, String), LoraConfig)],
830 xlora_config: Option<XLoraConfig>,
831 xlora_ordering: Ordering,
832 normal_loading_metadata: NormalLoadingMetadata,
833 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
834 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
835 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
836
837 Ok(Box::new(xlora_models::XLoraLlama::new(
838 &cfg,
839 vb,
840 lora_config,
841 xlora_config,
842 xlora_ordering,
843 self.is_gptx(config)?,
844 normal_loading_metadata,
845 preload_adapters,
846 )?))
847 }
848 fn is_gptx(&self, _: &str) -> Result<bool> {
849 Ok(true)
850 }
851 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
852 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
853 Ok(Box::new(cfg))
854 }
855}
856
857impl IsqModelLoader for LlamaLoader {
858 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
859 Ok(vec![
860 Regex::new(r"lm_head\.(weight|bias)$")?,
861 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
863 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
864 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
865 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
866 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
868 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
869 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
870 ])
871 }
872 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
873 self.isq_layer_regexes(config)
874 }
875}
876
877impl DeviceMappedModelLoader for LlamaLoader {
878 fn mapped_max_act_size_elems(
879 &self,
880 config: &str,
881 params: &AutoDeviceMapParams,
882 prompt_chunksize: usize,
883 ) -> Result<usize> {
884 let AutoDeviceMapParams::Text {
885 max_seq_len: _,
886 max_batch_size,
887 } = params
888 else {
889 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
890 };
891
892 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
893
894 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
895 }
896 fn non_mapped_max_act_size_elems(
897 &self,
898 _config: &str,
899 _params: &AutoDeviceMapParams,
900 ) -> Result<usize> {
901 Ok(0)
902 }
903
904 fn non_mapped_size_in_bytes(
905 &self,
906 config: &str,
907 dtype: DType,
908 weight_pack_factor: usize,
909 _matformer_config: Option<&MatformerSliceConfig>,
910 ) -> Result<usize> {
911 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
912
913 let elems = {
914 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
915 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
917 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
918 } else {
919 0
920 };
921 let norm = cfg.hidden_size;
922 embed_tokens + lm_head + norm
923 };
924 Ok(elems * dtype.size_in_bytes())
925 }
926
927 fn layer_sizes_in_bytes(
928 &self,
929 config: &str,
930 dtype: DType,
931 weight_pack_factor: usize,
932 _matformer_config: Option<&MatformerSliceConfig>,
933 ) -> Result<Vec<usize>> {
934 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
935
936 let per_layer_elems = {
937 let input_layernorm = cfg.hidden_size;
938 let post_attention_layernorm = cfg.hidden_size;
939
940 let size_in = cfg.hidden_size;
941 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
942 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
943 let q_proj = size_in * size_q / weight_pack_factor;
944 let k_proj = size_in * size_kv / weight_pack_factor;
945 let v_proj = size_in * size_kv / weight_pack_factor;
946 let o_proj = size_q * size_in / weight_pack_factor;
947
948 let h_size = cfg.hidden_size;
949 let i_size = cfg.intermediate_size;
950 let gate_proj = h_size * i_size / weight_pack_factor;
951 let up_proj = h_size * i_size / weight_pack_factor;
952 let down_proj = i_size * h_size / weight_pack_factor;
953
954 input_layernorm
955 + post_attention_layernorm
956 + q_proj
957 + k_proj
958 + v_proj
959 + o_proj
960 + gate_proj
961 + up_proj
962 + down_proj
963 };
964 Ok(vec![
965 per_layer_elems * dtype.size_in_bytes();
966 cfg.num_hidden_layers
967 ])
968 }
969
970 fn num_layers(&self, config: &str) -> Result<usize> {
971 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
972
973 Ok(cfg.num_hidden_layers)
974 }
975 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
976 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
977
978 let cfg = ModelConfigMetadata {
979 max_seq_len: cfg.max_position_embeddings,
980 num_layers: cfg.num_hidden_layers,
981 hidden_size: cfg.hidden_size,
982 num_kv_heads: cfg.num_key_value_heads,
983 num_attn_heads: cfg.num_attention_heads,
984 sliding_window: None,
985 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
986 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
987 };
988
989 Ok(Box::new(cfg))
990 }
991}
992
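/// [`NormalModelLoader`] for Mixtral mixture-of-experts models (`MixtralForCausalLM`).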
993pub struct MixtralLoader;
996
997impl NormalModelLoader for MixtralLoader {
998 fn load(
999 &self,
1000 config: &str,
1001 vb: ShardedVarBuilder,
1002 normal_loading_metadata: NormalLoadingMetadata,
1003 attention_mechanism: AttentionImplementation,
1004 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1005 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1006
1007 Ok(Box::new(models::mixtral::Model::new(
1008 &cfg,
1009 vb,
1010 self.is_gptx(config)?,
1011 normal_loading_metadata,
1012 attention_mechanism,
1013 )?))
1014 }
1015 fn load_xlora(
1016 &self,
1017 config: &str,
1018 vb: ShardedVarBuilder,
1019 lora_config: &[((String, String), LoraConfig)],
1020 xlora_config: Option<XLoraConfig>,
1021 xlora_ordering: Ordering,
1022 normal_loading_metadata: NormalLoadingMetadata,
1023 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1024 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1025 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1026
1027 Ok(Box::new(xlora_models::XLoraMixtral::new(
1028 &cfg,
1029 vb,
1030 lora_config,
1031 xlora_config,
1032 xlora_ordering,
1033 self.is_gptx(config)?,
1034 normal_loading_metadata,
1035 preload_adapters,
1036 )?))
1037 }
1038 fn is_gptx(&self, _: &str) -> Result<bool> {
1039 Ok(true)
1040 }
1041 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1042 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1043
1044 Ok(Box::new(cfg))
1045 }
1046}
1047
1048impl IsqModelLoader for MixtralLoader {
1049 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1050 Ok(vec![
1051 Regex::new(r"lm_head\.(weight|bias)$")?,
1052 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1054 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1055 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1056 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1057 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1059 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1060 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1061 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1062 ])
1063 }
1064 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1065 self.isq_layer_regexes(config)
1066 }
1067}
1068
1069impl DeviceMappedModelLoader for MixtralLoader {
1070 fn mapped_max_act_size_elems(
1071 &self,
1072 config: &str,
1073 params: &AutoDeviceMapParams,
1074 prompt_chunksize: usize,
1075 ) -> Result<usize> {
1076 let AutoDeviceMapParams::Text {
1077 max_seq_len: _,
1078 max_batch_size,
1079 } = params
1080 else {
1081 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1082 };
1083
1084 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1085
1086 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1087 }
1088 fn non_mapped_max_act_size_elems(
1089 &self,
1090 _config: &str,
1091 _params: &AutoDeviceMapParams,
1092 ) -> Result<usize> {
1093 Ok(0)
1094 }
1095
1096 fn non_mapped_size_in_bytes(
1097 &self,
1098 config: &str,
1099 dtype: DType,
1100 weight_pack_factor: usize,
1101 _matformer_config: Option<&MatformerSliceConfig>,
1102 ) -> Result<usize> {
1103 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1104
1105 let elems = {
1106 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1107 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1109 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1110 } else {
1111 0
1112 };
1113 let norm = cfg.hidden_size;
1114 embed_tokens + lm_head + norm
1115 };
1116 Ok(elems * dtype.size_in_bytes())
1117 }
1118
1119 fn layer_sizes_in_bytes(
1120 &self,
1121 config: &str,
1122 dtype: DType,
1123 weight_pack_factor: usize,
1124 _matformer_config: Option<&MatformerSliceConfig>,
1125 ) -> Result<Vec<usize>> {
1126 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1127
1128 let per_layer_elems = {
1129 let input_layernorm = cfg.hidden_size;
1130 let post_attention_layernorm = cfg.hidden_size;
1131
1132 let size_in = cfg.hidden_size;
1133 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1134 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1135 let q_proj = size_in * size_q / weight_pack_factor;
1136 let k_proj = size_in * size_kv / weight_pack_factor;
1137 let v_proj = size_in * size_kv / weight_pack_factor;
1138 let o_proj = size_q * size_in / weight_pack_factor;
1139
1140 let moe_block = {
1141 let gate = cfg.hidden_size * cfg.num_local_experts;
1142 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1144 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1145 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1146 gate + cfg.num_local_experts * w1
1147 + cfg.num_local_experts * w2
1148 + cfg.num_local_experts * w3
1149 };
1150
1151 input_layernorm
1152 + post_attention_layernorm
1153 + q_proj
1154 + k_proj
1155 + v_proj
1156 + o_proj
1157 + moe_block
1158 };
1159 Ok(vec![
1160 per_layer_elems * dtype.size_in_bytes();
1161 cfg.num_hidden_layers
1162 ])
1163 }
1164
1165 fn num_layers(&self, config: &str) -> Result<usize> {
1166 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1167
1168 Ok(cfg.num_hidden_layers)
1169 }
1170
1171 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1172 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1173
1174 let cfg = ModelConfigMetadata {
1175 max_seq_len: cfg.max_position_embeddings,
1176 num_layers: cfg.num_hidden_layers,
1177 hidden_size: cfg.hidden_size,
1178 num_kv_heads: cfg.num_key_value_heads,
1179 num_attn_heads: cfg.num_attention_heads,
1180 sliding_window: cfg.sliding_window,
1181 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1182 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1183 };
1184
1185 Ok(Box::new(cfg))
1186 }
1187}
1188
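/// [`NormalModelLoader`] for Phi 2 models (`PhiForCausalLM`).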
1189pub struct Phi2Loader;
1195
1196impl NormalModelLoader for Phi2Loader {
1197 fn load(
1198 &self,
1199 config: &str,
1200 vb: ShardedVarBuilder,
1201 normal_loading_metadata: NormalLoadingMetadata,
1202 attention_mechanism: AttentionImplementation,
1203 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1204 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1205
1206 Ok(Box::new(models::phi2::Model::new(
1207 &cfg,
1208 vb,
1209 self.is_gptx(config)?,
1210 normal_loading_metadata,
1211 attention_mechanism,
1212 )?))
1213 }
1214 fn load_xlora(
1215 &self,
1216 config: &str,
1217 vb: ShardedVarBuilder,
1218 lora_config: &[((String, String), LoraConfig)],
1219 xlora_config: Option<XLoraConfig>,
1220 xlora_ordering: Ordering,
1221 normal_loading_metadata: NormalLoadingMetadata,
1222 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1223 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1224 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1225
1226 Ok(Box::new(xlora_models::XLoraPhi2::new(
1227 &cfg,
1228 vb,
1229 lora_config,
1230 xlora_config,
1231 xlora_ordering,
1232 self.is_gptx(config)?,
1233 normal_loading_metadata,
1234 preload_adapters,
1235 )?))
1236 }
1237 fn is_gptx(&self, _: &str) -> Result<bool> {
1238 Ok(true)
1239 }
1240 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1241 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1242
1243 Ok(Box::new(cfg))
1244 }
1245}
1246
1247impl IsqModelLoader for Phi2Loader {
1248 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1249 Ok(vec![
1250 Regex::new(r"lm_head\.(weight|bias)$")?,
1251 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1253 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1254 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1255 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1256 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1258 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1259 ])
1260 }
1261 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1262 self.isq_layer_regexes(config)
1263 }
1264}
1265
1266impl DeviceMappedModelLoader for Phi2Loader {
1267 fn mapped_max_act_size_elems(
1268 &self,
1269 config: &str,
1270 params: &AutoDeviceMapParams,
1271 prompt_chunksize: usize,
1272 ) -> Result<usize> {
1273 let AutoDeviceMapParams::Text {
1274 max_seq_len: _,
1275 max_batch_size,
1276 } = params
1277 else {
1278 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1279 };
1280
1281 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1282
1283 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1284 }
1285 fn non_mapped_max_act_size_elems(
1286 &self,
1287 _config: &str,
1288 _params: &AutoDeviceMapParams,
1289 ) -> Result<usize> {
1290 Ok(0)
1291 }
1292
1293 fn non_mapped_size_in_bytes(
1294 &self,
1295 config: &str,
1296 dtype: DType,
1297 weight_pack_factor: usize,
1298 _matformer_config: Option<&MatformerSliceConfig>,
1299 ) -> Result<usize> {
1300 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1301
1302 let elems = {
1303 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1304 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1306 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1307 } else {
1308 0
1309 };
1310 let norm = cfg.hidden_size;
1311 embed_tokens + lm_head + norm
1312 };
1313 Ok(elems * dtype.size_in_bytes())
1314 }
1315
1316 fn layer_sizes_in_bytes(
1317 &self,
1318 config: &str,
1319 dtype: DType,
1320 weight_pack_factor: usize,
1321 _matformer_config: Option<&MatformerSliceConfig>,
1322 ) -> Result<Vec<usize>> {
1323 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1324
1325 let per_layer_elems = {
            // Phi-2 uses LayerNorm, which has both a weight and a bias of length hidden_size.
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1327
1328 let size_in = cfg.hidden_size;
1329 let size_q = cfg.head_dim() * cfg.num_attention_heads;
1330 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1331 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1332 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1333 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1334 let o_proj = size_q * size_in / weight_pack_factor + size_in;
1335 let (q_norm, k_norm) = if cfg.qk_layernorm {
1336 (cfg.head_dim(), cfg.head_dim())
1337 } else {
1338 (0, 0)
1339 };
1340
1341 let h_size = cfg.hidden_size;
1342 let i_size = cfg.intermediate_size;
1343 let fc1 = h_size * i_size / weight_pack_factor;
1344 let fc2 = h_size * i_size / weight_pack_factor;
1345
1346 input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1347 };
1348 Ok(vec![
1349 per_layer_elems * dtype.size_in_bytes();
1350 cfg.num_hidden_layers
1351 ])
1352 }
1353
1354 fn num_layers(&self, config: &str) -> Result<usize> {
1355 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1356
1357 Ok(cfg.num_hidden_layers)
1358 }
1359
1360 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1361 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1362
1363 let cfg = ModelConfigMetadata {
1364 max_seq_len: cfg.max_position_embeddings,
1365 num_layers: cfg.num_hidden_layers,
1366 hidden_size: cfg.hidden_size,
1367 num_kv_heads: cfg.num_key_value_heads(),
1368 num_attn_heads: cfg.num_attention_heads,
1369 sliding_window: None,
1370 k_head_dim: cfg.head_dim(),
1371 v_head_dim: cfg.head_dim(),
1372 };
1373
1374 Ok(Box::new(cfg))
1375 }
1376}
1377
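/// [`NormalModelLoader`] for Phi 3 models (`Phi3ForCausalLM`).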
1378pub struct Phi3Loader;
1384
1385impl NormalModelLoader for Phi3Loader {
1386 fn load(
1387 &self,
1388 config: &str,
1389 vb: ShardedVarBuilder,
1390 normal_loading_metadata: NormalLoadingMetadata,
1391 attention_mechanism: AttentionImplementation,
1392 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1393 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1394
1395 Ok(Box::new(models::phi3::Model::new(
1396 &cfg,
1397 vb,
1398 self.is_gptx(config)?,
1399 normal_loading_metadata,
1400 attention_mechanism,
1401 )?))
1402 }
1403 fn load_xlora(
1404 &self,
1405 config: &str,
1406 vb: ShardedVarBuilder,
1407 lora_config: &[((String, String), LoraConfig)],
1408 xlora_config: Option<XLoraConfig>,
1409 xlora_ordering: Ordering,
1410 normal_loading_metadata: NormalLoadingMetadata,
1411 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1412 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1413 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1414
1415 Ok(Box::new(xlora_models::XLoraPhi3::new(
1416 &cfg,
1417 vb,
1418 lora_config,
1419 xlora_config,
1420 xlora_ordering,
1421 self.is_gptx(config)?,
1422 normal_loading_metadata,
1423 preload_adapters,
1424 )?))
1425 }
1426 fn is_gptx(&self, _: &str) -> Result<bool> {
1427 Ok(true)
1428 }
1429 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1430 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1431
1432 Ok(Box::new(cfg))
1433 }
1434}
1435
1436impl IsqModelLoader for Phi3Loader {
1437 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1438 Ok(vec![
1439 Regex::new(r"lm_head\.(weight|bias)$")?,
1440 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1442 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1443 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1445 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1446 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1447 ])
1448 }
1449 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1450 self.isq_layer_regexes(config)
1451 }
1452}
1453
1454impl DeviceMappedModelLoader for Phi3Loader {
1455 fn mapped_max_act_size_elems(
1456 &self,
1457 config: &str,
1458 params: &AutoDeviceMapParams,
1459 prompt_chunksize: usize,
1460 ) -> Result<usize> {
1461 let AutoDeviceMapParams::Text {
1462 max_seq_len: _,
1463 max_batch_size,
1464 } = params
1465 else {
1466 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1467 };
1468
1469 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1470
1471 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1472 }
1473 fn non_mapped_max_act_size_elems(
1474 &self,
1475 _config: &str,
1476 _params: &AutoDeviceMapParams,
1477 ) -> Result<usize> {
1478 Ok(0)
1479 }
1480
1481 fn non_mapped_size_in_bytes(
1482 &self,
1483 config: &str,
1484 dtype: DType,
1485 weight_pack_factor: usize,
1486 _matformer_config: Option<&MatformerSliceConfig>,
1487 ) -> Result<usize> {
1488 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1489
1490 let elems = {
1491 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1492 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1494 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1495 } else {
1496 0
1497 };
1498 let norm = cfg.hidden_size;
1499 embed_tokens + lm_head + norm
1500 };
1501 Ok(elems * dtype.size_in_bytes())
1502 }
1503
1504 fn layer_sizes_in_bytes(
1505 &self,
1506 config: &str,
1507 dtype: DType,
1508 weight_pack_factor: usize,
1509 _matformer_config: Option<&MatformerSliceConfig>,
1510 ) -> Result<Vec<usize>> {
1511 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1512
1513 let per_layer_elems = {
1514 let input_layernorm = cfg.hidden_size;
1515 let post_attention_layernorm = cfg.hidden_size;
1516
1517 let size_in = cfg.hidden_size;
1518 let head_dim = cfg.head_dim();
1519 let op_size =
1520 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1521 let qkv_proj = size_in * op_size / weight_pack_factor;
1522 let o_proj =
1523 (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1524
1525 let h_size = cfg.hidden_size;
1526 let i_size = cfg.intermediate_size;
1527 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1528 let down_proj = h_size * i_size / weight_pack_factor;
1529
1530 input_layernorm
1531 + post_attention_layernorm
1532 + qkv_proj
1533 + o_proj
1534 + gate_up_proj
1535 + down_proj
1536 };
1537 Ok(vec![
1538 per_layer_elems * dtype.size_in_bytes();
1539 cfg.num_hidden_layers
1540 ])
1541 }
1542
1543 fn num_layers(&self, config: &str) -> Result<usize> {
1544 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1545
1546 Ok(cfg.num_hidden_layers)
1547 }
1548
1549 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1550 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1551
1552 let cfg = ModelConfigMetadata {
1553 max_seq_len: cfg.max_position_embeddings,
1554 num_layers: cfg.num_hidden_layers,
1555 hidden_size: cfg.hidden_size,
1556 num_kv_heads: cfg.num_key_value_heads,
1557 num_attn_heads: cfg.num_attention_heads,
1558 sliding_window: cfg.sliding_window,
1559 k_head_dim: cfg.head_dim(),
1560 v_head_dim: cfg.head_dim(),
1561 };
1562
1563 Ok(Box::new(cfg))
1564 }
1565}
1566
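/// [`NormalModelLoader`] for Qwen 2 models (`Qwen2ForCausalLM`).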
1567pub struct Qwen2Loader;
1573
1574impl NormalModelLoader for Qwen2Loader {
1575 fn load(
1576 &self,
1577 config: &str,
1578 vb: ShardedVarBuilder,
1579 normal_loading_metadata: NormalLoadingMetadata,
1580 attention_mechanism: AttentionImplementation,
1581 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1582 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1583
1584 Ok(Box::new(models::qwen2::Model::new(
1585 &cfg,
1586 vb,
1587 self.is_gptx(config)?,
1588 normal_loading_metadata,
1589 attention_mechanism,
1590 )?))
1591 }
1592 fn load_xlora(
1593 &self,
1594 _config: &str,
1595 _vb: ShardedVarBuilder,
1596 _lora_config: &[((String, String), LoraConfig)],
1597 _xlora_config: Option<XLoraConfig>,
1598 _xlora_ordering: Ordering,
1599 _normal_loading_metadata: NormalLoadingMetadata,
1600 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1601 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
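        // X-LoRA loading is not currently implemented for the Qwen 2 family.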
1602 todo!()
1603 }
1604 fn is_gptx(&self, _: &str) -> Result<bool> {
1605 Ok(true)
1606 }
1607 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1608 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1609
1610 Ok(Box::new(cfg))
1611 }
1612}
1613
1614impl IsqModelLoader for Qwen2Loader {
1615 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1616 Ok(vec![
1617 Regex::new(r"lm_head\.(weight|bias)$")?,
1618 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1620 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1621 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1622 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1623 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1625 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1626 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1627 ])
1628 }
1629 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1630 self.isq_layer_regexes(config)
1631 }
1632}
1633
1634impl DeviceMappedModelLoader for Qwen2Loader {
1635 fn mapped_max_act_size_elems(
1636 &self,
1637 config: &str,
1638 params: &AutoDeviceMapParams,
1639 prompt_chunksize: usize,
1640 ) -> Result<usize> {
1641 let AutoDeviceMapParams::Text {
1642 max_seq_len: _,
1643 max_batch_size,
1644 } = params
1645 else {
1646 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1647 };
1648
1649 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1650
1651 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1652 }
1653 fn non_mapped_max_act_size_elems(
1654 &self,
1655 _config: &str,
1656 _params: &AutoDeviceMapParams,
1657 ) -> Result<usize> {
1658 Ok(0)
1659 }
1660
1661 fn non_mapped_size_in_bytes(
1662 &self,
1663 config: &str,
1664 dtype: DType,
1665 weight_pack_factor: usize,
1666 _matformer_config: Option<&MatformerSliceConfig>,
1667 ) -> Result<usize> {
1668 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1669
1670 let elems = {
1671 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1672 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1674 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1675 } else {
1676 0
1677 };
1678 let norm = cfg.hidden_size;
1679 embed_tokens + lm_head + norm
1680 };
1681 Ok(elems * dtype.size_in_bytes())
1682 }
1683
1684 fn layer_sizes_in_bytes(
1685 &self,
1686 config: &str,
1687 dtype: DType,
1688 weight_pack_factor: usize,
1689 _matformer_config: Option<&MatformerSliceConfig>,
1690 ) -> Result<Vec<usize>> {
1691 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1692
1693 let per_layer_elems = {
1694 let input_layernorm = cfg.hidden_size;
1695 let post_attention_layernorm = cfg.hidden_size;
1696
1697 let size_in = cfg.hidden_size;
1698 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1699 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1700 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1701 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1702 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1703 let o_proj = size_q * size_in / weight_pack_factor;
1704
1705 let h_size = cfg.hidden_size;
1706 let i_size = cfg.intermediate_size;
1707 let gate_proj = h_size * i_size / weight_pack_factor;
1708 let up_proj = h_size * i_size / weight_pack_factor;
1709 let down_proj = i_size * h_size / weight_pack_factor;
1710
1711 input_layernorm
1712 + post_attention_layernorm
1713 + q_proj
1714 + k_proj
1715 + v_proj
1716 + o_proj
1717 + gate_proj
1718 + up_proj
1719 + down_proj
1720 };
1721 Ok(vec![
1722 per_layer_elems * dtype.size_in_bytes();
1723 cfg.num_hidden_layers
1724 ])
1725 }
1726
1727 fn num_layers(&self, config: &str) -> Result<usize> {
1728 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1729
1730 Ok(cfg.num_hidden_layers)
1731 }
1732
1733 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1734 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1735
1736 let cfg = ModelConfigMetadata {
1737 max_seq_len: cfg.max_position_embeddings,
1738 num_layers: cfg.num_hidden_layers,
1739 hidden_size: cfg.hidden_size,
1740 num_kv_heads: cfg.num_key_value_heads,
1741 num_attn_heads: cfg.num_attention_heads,
1742 sliding_window: cfg.sliding_window,
1743 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1744 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1745 };
1746
1747 Ok(Box::new(cfg))
1748 }
1749}
1750
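/// [`NormalModelLoader`] for Gemma 2 models (`Gemma2ForCausalLM`).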
1751pub struct Gemma2Loader;
1757
1758impl NormalModelLoader for Gemma2Loader {
1759 fn load(
1760 &self,
1761 config: &str,
1762 vb: ShardedVarBuilder,
1763 normal_loading_metadata: NormalLoadingMetadata,
1764 attention_mechanism: AttentionImplementation,
1765 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1766 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1767
1768 Ok(Box::new(models::gemma2::Model::new(
1769 &cfg,
1770 vb,
1771 self.is_gptx(config)?,
1772 normal_loading_metadata,
1773 attention_mechanism,
1774 )?))
1775 }
1776 fn load_xlora(
1777 &self,
1778 config: &str,
1779 vb: ShardedVarBuilder,
1780 lora_config: &[((String, String), LoraConfig)],
1781 xlora_config: Option<XLoraConfig>,
1782 xlora_ordering: Ordering,
1783 normal_loading_metadata: NormalLoadingMetadata,
1784 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1785 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1786 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1787
1788 Ok(Box::new(xlora_models::XLoraGemma2::new(
1789 &cfg,
1790 vb,
1791 lora_config,
1792 xlora_config,
1793 xlora_ordering,
1794 self.is_gptx(config)?,
1795 normal_loading_metadata,
1796 preload_adapters,
1797 )?))
1798 }
1799 fn is_gptx(&self, _: &str) -> Result<bool> {
1800 Ok(true)
1801 }
1802 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1803 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1804
1805 Ok(Box::new(cfg))
1806 }
1807}
1808
1809impl IsqModelLoader for Gemma2Loader {
1810 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1811 Ok(vec![
1812 Regex::new(r"lm_head\.(weight|bias)$")?,
1813 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1815 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1816 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1817 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1818 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1820 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1821 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1822 ])
1823 }
1824 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1825 self.isq_layer_regexes(config)
1826 }
1827}
1828
1829impl DeviceMappedModelLoader for Gemma2Loader {
1830 fn mapped_max_act_size_elems(
1831 &self,
1832 config: &str,
1833 params: &AutoDeviceMapParams,
1834 prompt_chunksize: usize,
1835 ) -> Result<usize> {
1836 let AutoDeviceMapParams::Text {
1837 max_seq_len: _,
1838 max_batch_size,
1839 } = params
1840 else {
1841 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1842 };
1843
1844 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1845
1846 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1847 }
1848 fn non_mapped_max_act_size_elems(
1849 &self,
1850 _config: &str,
1851 _params: &AutoDeviceMapParams,
1852 ) -> Result<usize> {
1853 Ok(0)
1854 }
1855
1856 fn non_mapped_size_in_bytes(
1857 &self,
1858 config: &str,
1859 dtype: DType,
1860 weight_pack_factor: usize,
1861 _matformer_config: Option<&MatformerSliceConfig>,
1862 ) -> Result<usize> {
1863 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1864
1865 let elems = {
1866 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1867 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1869 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1870 } else {
1871 0
1872 };
1873 let norm = cfg.hidden_size;
1874 embed_tokens + lm_head + norm
1875 };
1876 Ok(elems * dtype.size_in_bytes())
1877 }
1878
1879 fn layer_sizes_in_bytes(
1880 &self,
1881 config: &str,
1882 dtype: DType,
1883 weight_pack_factor: usize,
1884 _matformer_config: Option<&MatformerSliceConfig>,
1885 ) -> Result<Vec<usize>> {
1886 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1887
1888 let per_layer_elems = {
1889 let input_layernorm = cfg.hidden_size;
1890 let post_attention_layernorm = cfg.hidden_size;
1891
1892 let size_in = cfg.hidden_size;
1893 let size_q = cfg.head_dim * cfg.num_attention_heads;
1894 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1895 let q_proj =
1896 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1897 let k_proj =
1898 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1899 let v_proj =
1900 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1901 let o_proj =
1902 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1903
1904 let h_size = cfg.hidden_size;
1905 let i_size = cfg.intermediate_size;
1906 let gate_proj = h_size * i_size / weight_pack_factor;
1907 let up_proj = h_size * i_size / weight_pack_factor;
1908 let down_proj = i_size * h_size / weight_pack_factor;
1909
1910 input_layernorm
1911 + post_attention_layernorm
1912 + q_proj
1913 + k_proj
1914 + v_proj
1915 + o_proj
1916 + gate_proj
1917 + up_proj
1918 + down_proj
1919 };
1920 Ok(vec![
1921 per_layer_elems * dtype.size_in_bytes();
1922 cfg.num_hidden_layers
1923 ])
1924 }
1925
1926 fn num_layers(&self, config: &str) -> Result<usize> {
1927 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1928
1929 Ok(cfg.num_hidden_layers)
1930 }
1931 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1932 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1933
1934 let cfg = ModelConfigMetadata {
1935 max_seq_len: cfg.max_position_embeddings,
1936 num_layers: cfg.num_hidden_layers,
1937 hidden_size: cfg.hidden_size,
1938 num_kv_heads: cfg.num_key_value_heads,
1939 num_attn_heads: cfg.num_attention_heads,
1940 sliding_window: None,
1941 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1942 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1943 };
1944
1945 Ok(Box::new(cfg))
1946 }
1947}
1948
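/// Loader for the Starcoder2 text model architecture (`models::starcoder2`).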
1949pub struct Starcoder2Loader;
1955
1956impl NormalModelLoader for Starcoder2Loader {
1957 fn load(
1958 &self,
1959 config: &str,
1960 vb: ShardedVarBuilder,
1961 normal_loading_metadata: NormalLoadingMetadata,
1962 attention_mechanism: AttentionImplementation,
1963 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1964 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1965
1966 Ok(Box::new(models::starcoder2::Model::new(
1967 &cfg,
1968 vb,
1969 self.is_gptx(config)?,
1970 normal_loading_metadata,
1971 attention_mechanism,
1972 )?))
1973 }
1974 fn load_xlora(
1975 &self,
1976 config: &str,
1977 vb: ShardedVarBuilder,
1978 lora_config: &[((String, String), LoraConfig)],
1979 xlora_config: Option<XLoraConfig>,
1980 xlora_ordering: Ordering,
1981 normal_loading_metadata: NormalLoadingMetadata,
1982 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1983 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1984 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1985
1986 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
1987 &cfg,
1988 vb,
1989 lora_config,
1990 xlora_config,
1991 xlora_ordering,
1992 self.is_gptx(config)?,
1993 normal_loading_metadata,
1994 preload_adapters,
1995 )?))
1996 }
1997 fn is_gptx(&self, _: &str) -> Result<bool> {
1998 Ok(true)
1999 }
2000 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2001 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2002
2003 Ok(Box::new(cfg))
2004 }
2005}
2006
2007impl IsqModelLoader for Starcoder2Loader {
2008 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2009 Ok(vec![
2010 Regex::new(r"lm_head\.(weight|bias)$")?,
2011 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2013 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2014 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2015 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2016 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2018 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2019 ])
2020 }
2021 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2022 self.isq_layer_regexes(config)
2023 }
2024}
2025
2026impl DeviceMappedModelLoader for Starcoder2Loader {
2027 fn mapped_max_act_size_elems(
2028 &self,
2029 config: &str,
2030 params: &AutoDeviceMapParams,
2031 prompt_chunksize: usize,
2032 ) -> Result<usize> {
2033 let AutoDeviceMapParams::Text {
2034 max_seq_len: _,
2035 max_batch_size,
2036 } = params
2037 else {
2038 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2039 };
2040
2041 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2042
2043 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2044 }
2045 fn non_mapped_max_act_size_elems(
2046 &self,
2047 _config: &str,
2048 _params: &AutoDeviceMapParams,
2049 ) -> Result<usize> {
2050 Ok(0)
2051 }
2052
2053 fn non_mapped_size_in_bytes(
2054 &self,
2055 config: &str,
2056 dtype: DType,
2057 weight_pack_factor: usize,
2058 _matformer_config: Option<&MatformerSliceConfig>,
2059 ) -> Result<usize> {
2060 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2061
2062 let elems = {
2063 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2064 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2066 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2067 } else {
2068 0
2069 };
2070 let norm = cfg.hidden_size + cfg.hidden_size;
2071 embed_tokens + lm_head + norm
2072 };
2073 Ok(elems * dtype.size_in_bytes())
2074 }
2075
2076 fn layer_sizes_in_bytes(
2077 &self,
2078 config: &str,
2079 dtype: DType,
2080 weight_pack_factor: usize,
2081 _matformer_config: Option<&MatformerSliceConfig>,
2082 ) -> Result<Vec<usize>> {
2083 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2084
2085 let per_layer_elems = {
2086 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2087 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2088
2089 let size_in = cfg.hidden_size;
2090 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2091 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2092 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2093 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2094 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2095 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2096
2097 let h_size = cfg.hidden_size;
2098 let i_size = cfg.intermediate_size;
2099 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2100 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2101
2102 input_layernorm
2103 + post_attention_layernorm
2104 + q_proj
2105 + k_proj
2106 + v_proj
2107 + o_proj
2108 + fc1
2109 + fc2
2110 };
2111 Ok(vec![
2112 per_layer_elems * dtype.size_in_bytes();
2113 cfg.num_hidden_layers
2114 ])
2115 }
2116
2117 fn num_layers(&self, config: &str) -> Result<usize> {
2118 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2119
2120 Ok(cfg.num_hidden_layers)
2121 }
2122
2123 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2124 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2125
2126 let cfg = ModelConfigMetadata {
2127 max_seq_len: cfg.max_position_embeddings,
2128 num_layers: cfg.num_hidden_layers,
2129 hidden_size: cfg.hidden_size,
2130 num_kv_heads: cfg.num_key_value_heads,
2131 num_attn_heads: cfg.num_attention_heads,
2132 sliding_window: cfg.sliding_window,
2133 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2134 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2135 };
2136
2137 Ok(Box::new(cfg))
2138 }
2139}
2140
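/// Loader for the Phi 3.5 MoE text model architecture (`models::phi3_5_moe`).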
2141pub struct Phi3_5MoELoader;
2147
2148impl NormalModelLoader for Phi3_5MoELoader {
2149 fn load(
2150 &self,
2151 config: &str,
2152 vb: ShardedVarBuilder,
2153 normal_loading_metadata: NormalLoadingMetadata,
2154 attention_mechanism: AttentionImplementation,
2155 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2156 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2157
2158 Ok(Box::new(models::phi3_5_moe::Model::new(
2159 &cfg,
2160 vb,
2161 self.is_gptx(config)?,
2162 normal_loading_metadata,
2163 attention_mechanism,
2164 )?))
2165 }
2166 fn load_xlora(
2167 &self,
2168 config: &str,
2169 vb: ShardedVarBuilder,
2170 lora_config: &[((String, String), LoraConfig)],
2171 xlora_config: Option<XLoraConfig>,
2172 xlora_ordering: Ordering,
2173 normal_loading_metadata: NormalLoadingMetadata,
2174 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2175 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2176 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2177
2178 Ok(Box::new(xlora_models::XLoraPhi3::new(
2179 &cfg,
2180 vb,
2181 lora_config,
2182 xlora_config,
2183 xlora_ordering,
2184 self.is_gptx(config)?,
2185 normal_loading_metadata,
2186 preload_adapters,
2187 )?))
2188 }
2189 fn is_gptx(&self, _: &str) -> Result<bool> {
2190 Ok(true)
2191 }
2192 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2193 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2194
2195 Ok(Box::new(cfg))
2196 }
2197}
2198
2199impl IsqModelLoader for Phi3_5MoELoader {
2200 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2201 Ok(vec![
2202 Regex::new(r"lm_head\.(weight|bias)$")?,
2203 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2205 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2206 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2207 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2208 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2210 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2211 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2212 ])
2213 }
2214 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2215 self.isq_layer_regexes(config)
2216 }
2217
2218 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2219 Ok(vec![
2220 Regex::new(r"lm_head\.(weight|bias)$")?,
2221 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2223 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2224 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2225 ])
2226 }
2227 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2228 self.isq_layer_regexes_moqe(config)
2229 }
2230}
2231
2232impl DeviceMappedModelLoader for Phi3_5MoELoader {
2233 fn mapped_max_act_size_elems(
2234 &self,
2235 config: &str,
2236 params: &AutoDeviceMapParams,
2237 prompt_chunksize: usize,
2238 ) -> Result<usize> {
2239 let AutoDeviceMapParams::Text {
2240 max_seq_len: _,
2241 max_batch_size,
2242 } = params
2243 else {
2244 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2245 };
2246
2247 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2248
2249 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2250 }
2251 fn non_mapped_max_act_size_elems(
2252 &self,
2253 _config: &str,
2254 _params: &AutoDeviceMapParams,
2255 ) -> Result<usize> {
2256 Ok(0)
2257 }
2258
2259 fn non_mapped_size_in_bytes(
2260 &self,
2261 config: &str,
2262 dtype: DType,
2263 weight_pack_factor: usize,
2264 _matformer_config: Option<&MatformerSliceConfig>,
2265 ) -> Result<usize> {
2266 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2267
2268 let elems = {
2269 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2270 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2272 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2273 } else {
2274 0
2275 };
2276 let norm = cfg.hidden_size;
2277 embed_tokens + lm_head + norm
2278 };
2279 Ok(elems * dtype.size_in_bytes())
2280 }
2281
2282 fn layer_sizes_in_bytes(
2283 &self,
2284 config: &str,
2285 dtype: DType,
2286 weight_pack_factor: usize,
2287 _matformer_config: Option<&MatformerSliceConfig>,
2288 ) -> Result<Vec<usize>> {
2289 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2290
2291 let per_layer_elems = {
2292 let input_layernorm = cfg.hidden_size;
2293 let post_attention_layernorm = cfg.hidden_size;
2294
2295 let size_in = cfg.hidden_size;
2296 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2297 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2298 let q_proj =
2299 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2300 let k_proj =
2301 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2302 let v_proj =
2303 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2304 let o_proj =
2305 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2306
2307 let moe_block = {
2308 let gate = cfg.hidden_size * cfg.num_local_experts;
2309 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2311 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2312 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2313 gate + cfg.num_local_experts * w1
2314 + cfg.num_local_experts * w2
2315 + cfg.num_local_experts * w3
2316 };
2317
2318 input_layernorm
2319 + post_attention_layernorm
2320 + q_proj
2321 + k_proj
2322 + v_proj
2323 + o_proj
2324 + moe_block
2325 };
2326 Ok(vec![
2327 per_layer_elems * dtype.size_in_bytes();
2328 cfg.num_hidden_layers
2329 ])
2330 }
2331
2332 fn num_layers(&self, config: &str) -> Result<usize> {
2333 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2334
2335 Ok(cfg.num_hidden_layers)
2336 }
2337
2338 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2339 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2340
2341 let cfg = ModelConfigMetadata {
2342 max_seq_len: cfg.max_position_embeddings,
2343 num_layers: cfg.num_hidden_layers,
2344 hidden_size: cfg.hidden_size,
2345 num_kv_heads: cfg.num_key_value_heads,
2346 num_attn_heads: cfg.num_attention_heads,
2347 sliding_window: cfg.sliding_window,
2348 k_head_dim: cfg.head_dim(),
2349 v_head_dim: cfg.head_dim(),
2350 };
2351
2352 Ok(Box::new(cfg))
2353 }
2354}
2355
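/// Loader for the DeepSeek V2 text model architecture (`models::deepseek2`).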
2356pub struct DeepSeekV2Loader;
2360
2361impl NormalModelLoader for DeepSeekV2Loader {
2362 fn load(
2363 &self,
2364 config: &str,
2365 vb: ShardedVarBuilder,
2366 normal_loading_metadata: NormalLoadingMetadata,
2367 attention_mechanism: AttentionImplementation,
2368 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2369 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2370
2371 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2372 &cfg,
2373 vb,
2374 self.is_gptx(config)?,
2375 normal_loading_metadata,
2376 attention_mechanism,
2377 )?))
2378 }
2379 fn load_xlora(
2380 &self,
2381 _config: &str,
2382 _vb: ShardedVarBuilder,
2383 _lora_config: &[((String, String), LoraConfig)],
2384 _xlora_config: Option<XLoraConfig>,
2385 _xlora_ordering: Ordering,
2386 _normal_loading_metadata: NormalLoadingMetadata,
2387 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2388 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2389 todo!()
2390 }
2391 fn is_gptx(&self, _: &str) -> Result<bool> {
2392 Ok(true)
2393 }
2394 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2395 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2396 Ok(Box::new(cfg))
2397 }
2398}
2399
2400impl IsqModelLoader for DeepSeekV2Loader {
2401 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2402 let mut data = vec![
2403 Regex::new(r"lm_head\.(weight|bias)$")?,
2404 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2406 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2407 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2408 ];
2409 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
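// With a low-rank (MLA) query projection the weights are split into q_a_proj/q_b_proj;
// otherwise a single q_proj is used.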
2410 if cfg.q_lora_rank.is_some() {
2411 data.extend(vec![
2412 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2413 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2414 ]);
2415 } else {
2416 data.push(Regex::new(
2417 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2418 )?);
2419 }
2420 for layer_idx in 0..cfg.num_hidden_layers {
2421 if cfg.n_routed_experts.is_some()
2422 && layer_idx >= cfg.first_k_dense_replace
2423 && layer_idx % cfg.moe_layer_freq == 0
2424 {
2425 for i in 0..cfg.n_routed_experts.unwrap() {
2426 data.extend(vec![
2427 Regex::new(&format!(
2428 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2429 ))?,
2430 Regex::new(&format!(
2431 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2432 ))?,
2433 Regex::new(&format!(
2434 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2435 ))?,
2436 ]);
2437 }
2438 if cfg.n_shared_experts.is_some() {
2439 data.extend(vec![
2440 Regex::new(&format!(
2441 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2442 ))?,
2443 Regex::new(&format!(
2444 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2445 ))?,
2446 Regex::new(&format!(
2447 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2448 ))?,
2449 ]);
2450 }
2451 } else {
2452 data.extend(vec![
2453 Regex::new(&format!(
2454 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2455 ))?,
2456 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2457 Regex::new(&format!(
2458 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2459 ))?,
2460 ]);
2461 };
2462 }
2463 Ok(data)
2464 }
2465 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2466 self.isq_layer_regexes(config)
2467 }
2468
2469 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2470 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2471 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2472 for layer_idx in 0..cfg.num_hidden_layers {
2473 if cfg.n_routed_experts.is_some()
2474 && layer_idx >= cfg.first_k_dense_replace
2475 && layer_idx % cfg.moe_layer_freq == 0
2476 {
2477 for i in 0..cfg.n_routed_experts.unwrap() {
2478 data.extend(vec![
2479 Regex::new(&format!(
2480 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2481 ))?,
2482 Regex::new(&format!(
2483 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2484 ))?,
2485 Regex::new(&format!(
2486 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2487 ))?,
2488 ]);
2489 }
2490 if cfg.n_shared_experts.is_some() {
2491 data.extend(vec![
2492 Regex::new(&format!(
2493 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2494 ))?,
2495 Regex::new(&format!(
2496 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2497 ))?,
2498 Regex::new(&format!(
2499 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2500 ))?,
2501 ]);
2502 }
2503 } else {
2504 data.extend(vec![
2505 Regex::new(&format!(
2506 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2507 ))?,
2508 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2509 Regex::new(&format!(
2510 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2511 ))?,
2512 ]);
2513 };
2514 }
2515 Ok(data)
2516 }
2517 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2518 self.isq_layer_regexes_moqe(config)
2519 }
2520}
2521
2522impl DeviceMappedModelLoader for DeepSeekV2Loader {
2523 fn mapped_max_act_size_elems(
2524 &self,
2525 config: &str,
2526 params: &AutoDeviceMapParams,
2527 prompt_chunksize: usize,
2528 ) -> Result<usize> {
2529 let AutoDeviceMapParams::Text {
2530 max_seq_len: _,
2531 max_batch_size,
2532 } = params
2533 else {
2534 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2535 };
2536
2537 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2538
2539 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2540 }
2541 fn non_mapped_max_act_size_elems(
2542 &self,
2543 _config: &str,
2544 _params: &AutoDeviceMapParams,
2545 ) -> Result<usize> {
2546 Ok(0)
2547 }
2548
2549 fn non_mapped_size_in_bytes(
2550 &self,
2551 config: &str,
2552 dtype: DType,
2553 weight_pack_factor: usize,
2554 _matformer_config: Option<&MatformerSliceConfig>,
2555 ) -> Result<usize> {
2556 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2557 let elems = {
2558 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2559 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2561 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2562 } else {
2563 0
2564 };
2565 let norm = cfg.hidden_size;
2566 embed_tokens + lm_head + norm
2567 };
2568 Ok(elems * dtype.size_in_bytes())
2569 }
2570
2571 fn layer_sizes_in_bytes(
2572 &self,
2573 config: &str,
2574 dtype: DType,
2575 weight_pack_factor: usize,
2576 _matformer_config: Option<&MatformerSliceConfig>,
2577 ) -> Result<Vec<usize>> {
2578 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2579 let mut per_layer_elems = Vec::new();
2580
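// Layer sizes differ per layer: the first `first_k_dense_replace` layers, and layers off the
// `moe_layer_freq` stride, use a dense MLP; the rest use the routed-expert MoE block.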
2581 for layer_idx in 0..cfg.num_hidden_layers {
2582 let input_layernorm = cfg.hidden_size;
2583 let post_attention_layernorm = cfg.hidden_size;
2584
2585 let q_proj = match cfg.q_lora_rank {
2586 Some(lora_rank) => {
2587 let a = cfg.hidden_size * lora_rank;
2588 let norm = lora_rank;
2589 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2590 a + norm + b
2591 }
2592 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2593 };
2594 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2595 / weight_pack_factor
2596 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2597 let kv_a_layernorm = cfg.kv_lora_rank;
2598 let kv_b_proj = cfg.kv_lora_rank
2599 * cfg.num_attention_heads
2600 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2601 / weight_pack_factor;
2602 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2603 / weight_pack_factor
2604 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2605
2606 let moe_block = {
2607 let mut sum = 0;
2608 if cfg.n_routed_experts.is_some()
2609 && layer_idx >= cfg.first_k_dense_replace
2610 && layer_idx % cfg.moe_layer_freq == 0
2611 {
2612 let h_size = cfg.hidden_size;
2613 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2614 * cfg.n_routed_experts.unwrap();
2615 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2616 * cfg.n_routed_experts.unwrap();
2617 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2618 * cfg.n_routed_experts.unwrap();
2619 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2620 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2621 / weight_pack_factor;
2622 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2623 / weight_pack_factor;
2624 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2625 / weight_pack_factor;
2626 gate_proj + up_proj + down_proj
2627 } else {
2628 0
2629 };
2630 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2631 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2632 } else {
2633 let h_size = cfg.hidden_size;
2634 let i_size = cfg.intermediate_size;
2635 let gate_proj = h_size * i_size / weight_pack_factor;
2636 let up_proj = h_size * i_size / weight_pack_factor;
2637 let down_proj = i_size * h_size / weight_pack_factor;
2638 sum += gate_proj + up_proj + down_proj;
2639 }
2640 sum
2641 };
2642
2643 per_layer_elems.push(
2644 input_layernorm
2645 + post_attention_layernorm
2646 + q_proj
2647 + kv_a_layernorm
2648 + kv_a_proj_with_mqa
2649 + kv_b_proj
2650 + o_proj
2651 + moe_block,
2652 );
2653 }
2654
2655 Ok(per_layer_elems
2656 .into_iter()
2657 .map(|x| x * dtype.size_in_bytes())
2658 .collect())
2659 }
2660
2661 fn num_layers(&self, config: &str) -> Result<usize> {
2662 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2663 Ok(cfg.num_hidden_layers)
2664 }
2665
2666 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2667 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2668
2669 let cfg = ModelConfigMetadata {
2670 max_seq_len: cfg.max_position_embeddings,
2671 num_layers: cfg.num_hidden_layers,
2672 hidden_size: cfg.hidden_size,
2673 num_kv_heads: cfg.num_attention_heads,
2674 num_attn_heads: cfg.num_attention_heads,
2675 sliding_window: None,
2676 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2677 v_head_dim: cfg.v_head_dim,
2678 };
2679
2680 Ok(Box::new(cfg))
2681 }
2682}
2683
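/// Loader for the DeepSeek V3 text model architecture (`models::deepseek3`).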
2684pub struct DeepSeekV3Loader;
2688
2689impl NormalModelLoader for DeepSeekV3Loader {
2690 fn load(
2691 &self,
2692 config: &str,
2693 vb: ShardedVarBuilder,
2694 normal_loading_metadata: NormalLoadingMetadata,
2695 attention_mechanism: AttentionImplementation,
2696 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2697 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2698 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2699 &cfg,
2700 vb,
2701 self.is_gptx(config)?,
2702 normal_loading_metadata,
2703 attention_mechanism,
2704 )?))
2705 }
2706 fn load_xlora(
2707 &self,
2708 _config: &str,
2709 _vb: ShardedVarBuilder,
2710 _lora_config: &[((String, String), LoraConfig)],
2711 _xlora_config: Option<XLoraConfig>,
2712 _xlora_ordering: Ordering,
2713 _normal_loading_metadata: NormalLoadingMetadata,
2714 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2715 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2716 todo!()
2717 }
2718 fn is_gptx(&self, _: &str) -> Result<bool> {
2719 Ok(true)
2720 }
2721 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2722 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2723 Ok(Box::new(cfg))
2724 }
2725}
2726
2727impl IsqModelLoader for DeepSeekV3Loader {
2728 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2729 let mut data = vec![
2730 Regex::new(r"lm_head\.(weight|bias)$")?,
2731 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2733 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2734 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2735 ];
2736 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
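// As for DeepSeek V2: a low-rank query projection splits into q_a_proj/q_b_proj, otherwise q_proj.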
2737 if cfg.q_lora_rank.is_some() {
2738 data.extend(vec![
2739 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2740 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2741 ]);
2742 } else {
2743 data.push(Regex::new(
2744 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2745 )?);
2746 }
2747 for layer_idx in 0..cfg.num_hidden_layers {
2748 if cfg.n_routed_experts.is_some()
2749 && layer_idx >= cfg.first_k_dense_replace
2750 && layer_idx % cfg.moe_layer_freq == 0
2751 {
2752 for i in 0..cfg.n_routed_experts.unwrap() {
2753 data.extend(vec![
2754 Regex::new(&format!(
2755 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2756 ))?,
2757 Regex::new(&format!(
2758 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2759 ))?,
2760 Regex::new(&format!(
2761 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2762 ))?,
2763 ]);
2764 }
2765 if cfg.n_shared_experts.is_some() {
2766 data.extend(vec![
2767 Regex::new(&format!(
2768 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2769 ))?,
2770 Regex::new(&format!(
2771 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2772 ))?,
2773 Regex::new(&format!(
2774 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2775 ))?,
2776 ]);
2777 }
2778 } else {
2779 data.extend(vec![
2780 Regex::new(&format!(
2781 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2782 ))?,
2783 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2784 Regex::new(&format!(
2785 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2786 ))?,
2787 ]);
2788 };
2789 }
2790 Ok(data)
2791 }
2792 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2793 self.isq_layer_regexes(config)
2794 }
2795
2796 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2797 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2798 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2799 for layer_idx in 0..cfg.num_hidden_layers {
2800 if cfg.n_routed_experts.is_some()
2801 && layer_idx >= cfg.first_k_dense_replace
2802 && layer_idx % cfg.moe_layer_freq == 0
2803 {
2804 for i in 0..cfg.n_routed_experts.unwrap() {
2805 data.extend(vec![
2806 Regex::new(&format!(
2807 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2808 ))?,
2809 Regex::new(&format!(
2810 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2811 ))?,
2812 Regex::new(&format!(
2813 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2814 ))?,
2815 ]);
2816 }
2817 if cfg.n_shared_experts.is_some() {
2818 data.extend(vec![
2819 Regex::new(&format!(
2820 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2821 ))?,
2822 Regex::new(&format!(
2823 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2824 ))?,
2825 Regex::new(&format!(
2826 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2827 ))?,
2828 ]);
2829 }
2830 } else {
2831 data.extend(vec![
2832 Regex::new(&format!(
2833 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2834 ))?,
2835 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2836 Regex::new(&format!(
2837 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2838 ))?,
2839 ]);
2840 };
2841 }
2842 Ok(data)
2843 }
2844 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2845 self.isq_layer_regexes_moqe(config)
2846 }
2847}
2848
2849impl DeviceMappedModelLoader for DeepSeekV3Loader {
2850 fn mapped_max_act_size_elems(
2851 &self,
2852 config: &str,
2853 params: &AutoDeviceMapParams,
2854 prompt_chunksize: usize,
2855 ) -> Result<usize> {
2856 let AutoDeviceMapParams::Text {
2857 max_seq_len: _,
2858 max_batch_size,
2859 } = params
2860 else {
2861 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2862 };
2863
2864 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2865
2866 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2867 }
2868 fn non_mapped_max_act_size_elems(
2869 &self,
2870 _config: &str,
2871 _params: &AutoDeviceMapParams,
2872 ) -> Result<usize> {
2873 Ok(0)
2874 }
2875
2876 fn non_mapped_size_in_bytes(
2877 &self,
2878 config: &str,
2879 dtype: DType,
2880 weight_pack_factor: usize,
2881 _matformer_config: Option<&MatformerSliceConfig>,
2882 ) -> Result<usize> {
2883 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2884 let elems = {
2885 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2886 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2888 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2889 } else {
2890 0
2891 };
2892 let norm = cfg.hidden_size;
2893 embed_tokens + lm_head + norm
2894 };
2895 Ok(elems * dtype.size_in_bytes())
2896 }
2897
2898 fn layer_sizes_in_bytes(
2899 &self,
2900 config: &str,
2901 dtype: DType,
2902 weight_pack_factor: usize,
2903 _matformer_config: Option<&MatformerSliceConfig>,
2904 ) -> Result<Vec<usize>> {
2905 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2906 let mut per_layer_elems = Vec::new();
2907
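// Same dense-vs-MoE layer split as DeepSeek V2, so sizes are computed layer by layer.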
2908 for layer_idx in 0..cfg.num_hidden_layers {
2909 let input_layernorm = cfg.hidden_size;
2910 let post_attention_layernorm = cfg.hidden_size;
2911
2912 let q_proj = match cfg.q_lora_rank {
2913 Some(lora_rank) => {
2914 let a = cfg.hidden_size * lora_rank;
2915 let norm = lora_rank;
2916 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2917 a + norm + b
2918 }
2919 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2920 };
2921 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2922 / weight_pack_factor
2923 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2924 let kv_a_layernorm = cfg.kv_lora_rank;
2925 let kv_b_proj = cfg.kv_lora_rank
2926 * cfg.num_attention_heads
2927 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2928 / weight_pack_factor;
2929 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2930 / weight_pack_factor
2931 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2932
2933 let moe_block = {
2934 let mut sum = 0;
2935 if cfg.n_routed_experts.is_some()
2936 && layer_idx >= cfg.first_k_dense_replace
2937 && layer_idx % cfg.moe_layer_freq == 0
2938 {
2939 let h_size = cfg.hidden_size;
2940 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2941 * cfg.n_routed_experts.unwrap();
2942 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2943 * cfg.n_routed_experts.unwrap();
2944 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2945 * cfg.n_routed_experts.unwrap();
2946 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2947 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2948 / weight_pack_factor;
2949 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2950 / weight_pack_factor;
2951 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2952 / weight_pack_factor;
2953 gate_proj + up_proj + down_proj
2954 } else {
2955 0
2956 };
2957 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2958 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2959 } else {
2960 let h_size = cfg.hidden_size;
2961 let i_size = cfg.intermediate_size;
2962 let gate_proj = h_size * i_size / weight_pack_factor;
2963 let up_proj = h_size * i_size / weight_pack_factor;
2964 let down_proj = i_size * h_size / weight_pack_factor;
2965 sum += gate_proj + up_proj + down_proj;
2966 }
2967 sum
2968 };
2969
2970 per_layer_elems.push(
2971 input_layernorm
2972 + post_attention_layernorm
2973 + q_proj
2974 + kv_a_layernorm
2975 + kv_a_proj_with_mqa
2976 + kv_b_proj
2977 + o_proj
2978 + moe_block,
2979 );
2980 }
2981
2982 Ok(per_layer_elems
2983 .into_iter()
2984 .map(|x| x * dtype.size_in_bytes())
2985 .collect())
2986 }
2987
2988 fn num_layers(&self, config: &str) -> Result<usize> {
2989 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2990 Ok(cfg.num_hidden_layers)
2991 }
2992
2993 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2994 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2995
2996 let cfg = ModelConfigMetadata {
2997 max_seq_len: cfg.max_position_embeddings,
2998 num_layers: cfg.num_hidden_layers,
2999 hidden_size: cfg.hidden_size,
3000 num_kv_heads: cfg.num_attention_heads,
3001 num_attn_heads: cfg.num_attention_heads,
3002 sliding_window: None,
3003 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3004 v_head_dim: cfg.v_head_dim,
3005 };
3006
3007 Ok(Box::new(cfg))
3008 }
3009}
3010
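/// Loader for the Qwen 3 text model architecture (`models::qwen3`).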
3011pub struct Qwen3Loader;
3015
3016impl NormalModelLoader for Qwen3Loader {
3017 fn load(
3018 &self,
3019 config: &str,
3020 vb: ShardedVarBuilder,
3021 normal_loading_metadata: NormalLoadingMetadata,
3022 attention_mechanism: AttentionImplementation,
3023 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3024 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3025
3026 Ok(Box::new(models::qwen3::Model::new(
3027 &cfg,
3028 vb,
3029 self.is_gptx(config)?,
3030 normal_loading_metadata,
3031 attention_mechanism,
3032 )?))
3033 }
3034 fn load_xlora(
3035 &self,
3036 _config: &str,
3037 _vb: ShardedVarBuilder,
3038 _lora_config: &[((String, String), LoraConfig)],
3039 _xlora_config: Option<XLoraConfig>,
3040 _xlora_ordering: Ordering,
3041 _normal_loading_metadata: NormalLoadingMetadata,
3042 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3043 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3044 todo!()
3045 }
3046 fn is_gptx(&self, _: &str) -> Result<bool> {
3047 Ok(true)
3048 }
3049 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3050 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3051
3052 Ok(Box::new(cfg))
3053 }
3054}
3055
3056impl IsqModelLoader for Qwen3Loader {
3057 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3058 Ok(vec![
3059 Regex::new(r"lm_head\.(weight|bias)$")?,
3060 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3062 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3063 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3064 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3065 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3067 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3068 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3069 ])
3070 }
3071 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3072 self.isq_layer_regexes(config)
3073 }
3074}
3075
3076impl DeviceMappedModelLoader for Qwen3Loader {
3077 fn mapped_max_act_size_elems(
3078 &self,
3079 config: &str,
3080 params: &AutoDeviceMapParams,
3081 prompt_chunksize: usize,
3082 ) -> Result<usize> {
3083 let AutoDeviceMapParams::Text {
3084 max_seq_len: _,
3085 max_batch_size,
3086 } = params
3087 else {
3088 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3089 };
3090
3091 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3092
3093 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3094 }
3095 fn non_mapped_max_act_size_elems(
3096 &self,
3097 _config: &str,
3098 _params: &AutoDeviceMapParams,
3099 ) -> Result<usize> {
3100 Ok(0)
3101 }
3102
3103 fn non_mapped_size_in_bytes(
3104 &self,
3105 config: &str,
3106 dtype: DType,
3107 weight_pack_factor: usize,
3108 _matformer_config: Option<&MatformerSliceConfig>,
3109 ) -> Result<usize> {
3110 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3111 let elems = {
3112 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3113 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3115 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3116 } else {
3117 0
3118 };
3119 let norm = cfg.hidden_size;
3120 embed_tokens + lm_head + norm
3121 };
3122 Ok(elems * dtype.size_in_bytes())
3123 }
3124
3125 fn layer_sizes_in_bytes(
3126 &self,
3127 config: &str,
3128 dtype: DType,
3129 weight_pack_factor: usize,
3130 _matformer_config: Option<&MatformerSliceConfig>,
3131 ) -> Result<Vec<usize>> {
3132 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3133 let per_layer_elems = {
3134 let input_layernorm = cfg.hidden_size;
3135 let post_attention_layernorm = cfg.hidden_size;
3136
3137 let size_in = cfg.hidden_size;
3138 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3139 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3140 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3141 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3142 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3143 let o_proj = size_q * size_in / weight_pack_factor;
3144
3145 let h_size = cfg.hidden_size;
3146 let i_size = cfg.intermediate_size;
3147 let gate_proj = h_size * i_size / weight_pack_factor;
3148 let up_proj = h_size * i_size / weight_pack_factor;
3149 let down_proj = i_size * h_size / weight_pack_factor;
3150
3151 let q_norm = cfg.head_dim();
3152 let k_norm = cfg.head_dim();
3153
3154 input_layernorm
3155 + post_attention_layernorm
3156 + q_proj
3157 + k_proj
3158 + v_proj
3159 + o_proj
3160 + gate_proj
3161 + up_proj
3162 + down_proj
3163 + q_norm
3164 + k_norm
3165 };
3166 Ok(vec![
3167 per_layer_elems * dtype.size_in_bytes();
3168 cfg.num_hidden_layers
3169 ])
3170 }
3171
3172 fn num_layers(&self, config: &str) -> Result<usize> {
3173 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3174 Ok(cfg.num_hidden_layers)
3175 }
3176
3177 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3178 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3179
3180 let cfg = ModelConfigMetadata {
3181 max_seq_len: cfg.max_position_embeddings,
3182 num_layers: cfg.num_hidden_layers,
3183 hidden_size: cfg.hidden_size,
3184 num_kv_heads: cfg.num_key_value_heads,
3185 num_attn_heads: cfg.num_attention_heads,
3186 sliding_window: cfg.sliding_window,
3187 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3188 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3189 };
3190
3191 Ok(Box::new(cfg))
3192 }
3193}
3194
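/// Loader for the GLM4 text model architecture (`models::glm4`).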
3195pub struct GLM4Loader;
3199
3200impl NormalModelLoader for GLM4Loader {
3201 fn load(
3202 &self,
3203 config: &str,
3204 vb: ShardedVarBuilder,
3205 normal_loading_metadata: NormalLoadingMetadata,
3206 attention_mechanism: AttentionImplementation,
3207 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3208 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3209
3210 Ok(Box::new(models::glm4::Model::new(
3211 &cfg,
3212 vb,
3213 self.is_gptx(config)?,
3214 normal_loading_metadata,
3215 attention_mechanism,
3216 )?))
3217 }
3218 fn load_xlora(
3219 &self,
3220 _config: &str,
3221 _vb: ShardedVarBuilder,
3222 _lora_config: &[((String, String), LoraConfig)],
3223 _xlora_config: Option<XLoraConfig>,
3224 _xlora_ordering: Ordering,
3225 _normal_loading_metadata: NormalLoadingMetadata,
3226 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3227 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3228 todo!()
3229 }
3230 fn is_gptx(&self, _: &str) -> Result<bool> {
3231 Ok(true)
3232 }
3233 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3234 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3235
3236 Ok(Box::new(cfg))
3237 }
3238}
3239
3240impl IsqModelLoader for GLM4Loader {
3241 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3242 Ok(vec![
3243 Regex::new(r"lm_head\.(weight|bias)$")?,
3244 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3246 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3247 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3248 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3249 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3251 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3252 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3253 ])
3254 }
3255 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3256 self.isq_layer_regexes(config)
3257 }
3258}
3259
3260impl DeviceMappedModelLoader for GLM4Loader {
3261 fn mapped_max_act_size_elems(
3262 &self,
3263 config: &str,
3264 params: &AutoDeviceMapParams,
3265 prompt_chunksize: usize,
3266 ) -> Result<usize> {
3267 let AutoDeviceMapParams::Text {
3268 max_seq_len: _,
3269 max_batch_size,
3270 } = params
3271 else {
3272 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3273 };
3274
3275 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3276
3277 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3278 }
3279 fn non_mapped_max_act_size_elems(
3280 &self,
3281 _config: &str,
3282 _params: &AutoDeviceMapParams,
3283 ) -> Result<usize> {
3284 Ok(0)
3285 }
3286
3287 fn non_mapped_size_in_bytes(
3288 &self,
3289 config: &str,
3290 dtype: DType,
3291 weight_pack_factor: usize,
3292 _matformer_config: Option<&MatformerSliceConfig>,
3293 ) -> Result<usize> {
3294 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3295 let elems = {
3296 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3297 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3299 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3300 } else {
3301 0
3302 };
3303 let norm = cfg.hidden_size;
3304 embed_tokens + lm_head + norm
3305 };
3306 Ok(elems * dtype.size_in_bytes())
3307 }
3308
3309 fn layer_sizes_in_bytes(
3310 &self,
3311 config: &str,
3312 dtype: DType,
3313 weight_pack_factor: usize,
3314 _matformer_config: Option<&MatformerSliceConfig>,
3315 ) -> Result<Vec<usize>> {
3316 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3317 let per_layer_elems = {
3318 let input_layernorm = cfg.hidden_size;
3319 let post_attention_layernorm = cfg.hidden_size * 3; // also covers GLM4's post_self_attn and post_mlp layernorms
3320
3321 let size_in = cfg.hidden_size;
3322 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3323 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3324 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3325 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3326 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3327 let o_proj = size_q * size_in / weight_pack_factor;
3328
3329 let h_size = cfg.hidden_size;
3330 let i_size = cfg.intermediate_size;
3331 let gate_proj = h_size * i_size / weight_pack_factor;
3332 let up_proj = h_size * i_size / weight_pack_factor;
3333 let down_proj = i_size * h_size / weight_pack_factor;
3334
3335 input_layernorm
3336 + post_attention_layernorm
3337 + q_proj
3338 + k_proj
3339 + v_proj
3340 + o_proj
3341 + gate_proj
3342 + up_proj
3343 + down_proj
3344 };
3345 Ok(vec![
3346 per_layer_elems * dtype.size_in_bytes();
3347 cfg.num_hidden_layers
3348 ])
3349 }
3350
3351 fn num_layers(&self, config: &str) -> Result<usize> {
3352 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3353 Ok(cfg.num_hidden_layers)
3354 }
3355
3356 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3357 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3358
3359 let cfg = ModelConfigMetadata {
3360 max_seq_len: cfg.max_position_embeddings,
3361 num_layers: cfg.num_hidden_layers,
3362 hidden_size: cfg.hidden_size,
3363 num_kv_heads: cfg.num_key_value_heads,
3364 num_attn_heads: cfg.num_attention_heads,
3365 sliding_window: cfg.sliding_window,
3366 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3367 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3368 };
3369
3370 Ok(Box::new(cfg))
3371 }
3372}
3373
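/// Loader for the Qwen 3 MoE text model architecture (`models::qwen3_moe`).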
3374pub struct Qwen3MoELoader;
3378
3379impl NormalModelLoader for Qwen3MoELoader {
3380 fn load(
3381 &self,
3382 config: &str,
3383 vb: ShardedVarBuilder,
3384 normal_loading_metadata: NormalLoadingMetadata,
3385 attention_mechanism: AttentionImplementation,
3386 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3387 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3388
3389 Ok(Box::new(models::qwen3_moe::Model::new(
3390 &cfg,
3391 vb,
3392 self.is_gptx(config)?,
3393 normal_loading_metadata,
3394 attention_mechanism,
3395 )?))
3396 }
3397 fn load_xlora(
3398 &self,
3399 _config: &str,
3400 _vb: ShardedVarBuilder,
3401 _lora_config: &[((String, String), LoraConfig)],
3402 _xlora_config: Option<XLoraConfig>,
3403 _xlora_ordering: Ordering,
3404 _normal_loading_metadata: NormalLoadingMetadata,
3405 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3406 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3407 todo!()
3408 }
3409 fn is_gptx(&self, _: &str) -> Result<bool> {
3410 Ok(true)
3411 }
3412 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3413 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3414
3415 Ok(Box::new(cfg))
3416 }
3417}
3418
3419impl IsqModelLoader for Qwen3MoELoader {
3420 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3421 Ok(vec![
3422 Regex::new(r"lm_head\.(weight|bias)$")?,
3423 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3425 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3426 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3427 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3428 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3430 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3431 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3432 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3434 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3435 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3436 ])
3437 }
3438 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3439 self.isq_layer_regexes(config)
3440 }
3441 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3442 self.isq_layer_regexes_moqe(config)
3443 }
3444}
3445
3446impl DeviceMappedModelLoader for Qwen3MoELoader {
3447 fn mapped_max_act_size_elems(
3448 &self,
3449 config: &str,
3450 params: &AutoDeviceMapParams,
3451 prompt_chunksize: usize,
3452 ) -> Result<usize> {
3453 let AutoDeviceMapParams::Text {
3454 max_seq_len: _,
3455 max_batch_size,
3456 } = params
3457 else {
3458 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3459 };
3460
3461 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3462
3463 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3464 }
3465 fn non_mapped_max_act_size_elems(
3466 &self,
3467 _config: &str,
3468 _params: &AutoDeviceMapParams,
3469 ) -> Result<usize> {
3470 Ok(0)
3471 }
3472
3473 fn non_mapped_size_in_bytes(
3474 &self,
3475 config: &str,
3476 dtype: DType,
3477 weight_pack_factor: usize,
3478 _matformer_config: Option<&MatformerSliceConfig>,
3479 ) -> Result<usize> {
3480 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3481 let elems = {
3482 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3483 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3485 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3486 } else {
3487 0
3488 };
3489 let norm = cfg.hidden_size;
3490 embed_tokens + lm_head + norm
3491 };
3492 Ok(elems * dtype.size_in_bytes())
3493 }
3494
3495 fn layer_sizes_in_bytes(
3496 &self,
3497 config: &str,
3498 dtype: DType,
3499 weight_pack_factor: usize,
3500 _matformer_config: Option<&MatformerSliceConfig>,
3501 ) -> Result<Vec<usize>> {
3502 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3503
3504 let mut layer_sizes_in_bytes = Vec::new();
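// Layers listed in `mlp_only_layers`, or off the `decoder_sparse_step` stride, use a dense MLP;
// the remaining layers use the routed-expert MoE block.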
3505 for layer_idx in 0..cfg.num_hidden_layers {
3506 let input_layernorm = cfg.hidden_size;
3507 let post_attention_layernorm = cfg.hidden_size;
3508
3509 let size_in = cfg.hidden_size;
3510 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3511 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3512 let q_proj = size_in * size_q / weight_pack_factor;
3513 let k_proj = size_in * size_kv / weight_pack_factor;
3514 let v_proj = size_in * size_kv / weight_pack_factor;
3515 let o_proj = size_q * size_in / weight_pack_factor;
3516
3517 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3518 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3519 {
3520 let gate_size = cfg.hidden_size * cfg.num_experts;
3521 let expert_size = {
3522 let h_size = cfg.hidden_size;
3523 let i_size = cfg.moe_intermediate_size;
3524 let gate_proj = h_size * i_size / weight_pack_factor;
3525 let up_proj = h_size * i_size / weight_pack_factor;
3526 let down_proj = i_size * h_size / weight_pack_factor;
3527 gate_proj + up_proj + down_proj
3528 };
3529 expert_size * cfg.num_experts + gate_size
3530 } else {
3531 let h_size = cfg.hidden_size;
3532 let i_size = cfg.intermediate_size;
3533 let gate_proj = h_size * i_size / weight_pack_factor;
3534 let up_proj = h_size * i_size / weight_pack_factor;
3535 let down_proj = i_size * h_size / weight_pack_factor;
3536 gate_proj + up_proj + down_proj
3537 };
3538
3539 let q_norm = cfg.head_dim();
3540 let k_norm = cfg.head_dim();
3541
3542 let size_elems = input_layernorm
3543 + post_attention_layernorm
3544 + q_proj
3545 + k_proj
3546 + v_proj
3547 + o_proj
3548 + mlp_size
3549 + q_norm
3550 + k_norm;
3551
3552 let size_in_bytes = size_elems * dtype.size_in_bytes();
3553 layer_sizes_in_bytes.push(size_in_bytes);
3554 }
3555
3556 Ok(layer_sizes_in_bytes)
3557 }
3558
3559 fn num_layers(&self, config: &str) -> Result<usize> {
3560 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3561 Ok(cfg.num_hidden_layers)
3562 }
3563
3564 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3565 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3566
3567 let cfg = ModelConfigMetadata {
3568 max_seq_len: cfg.max_position_embeddings,
3569 num_layers: cfg.num_hidden_layers,
3570 hidden_size: cfg.hidden_size,
3571 num_kv_heads: cfg.num_key_value_heads,
3572 num_attn_heads: cfg.num_attention_heads,
3573 sliding_window: cfg.sliding_window,
3574 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3575 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3576 };
3577
3578 Ok(Box::new(cfg))
3579 }
3580}
3581
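/// Loader for the SmolLM3 text model architecture (`models::smollm3`).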
3582pub struct SmolLm3Loader;
3588
3589impl NormalModelLoader for SmolLm3Loader {
3590 fn load(
3591 &self,
3592 config: &str,
3593 vb: ShardedVarBuilder,
3594 normal_loading_metadata: NormalLoadingMetadata,
3595 attention_mechanism: AttentionImplementation,
3596 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3597 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3598
3599 Ok(Box::new(models::smollm3::SmolLm3::new(
3600 &cfg,
3601 vb,
3602 self.is_gptx(config)?,
3603 normal_loading_metadata,
3604 attention_mechanism,
3605 )?))
3606 }
3607 fn load_xlora(
3608 &self,
3609 _config: &str,
3610 _vb: ShardedVarBuilder,
3611 _lora_config: &[((String, String), LoraConfig)],
3612 _xlora_config: Option<XLoraConfig>,
3613 _xlora_ordering: Ordering,
3614 _normal_loading_metadata: NormalLoadingMetadata,
3615 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3616 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3617 todo!()
3618 }
3619 fn is_gptx(&self, _: &str) -> Result<bool> {
3620 Ok(true)
3621 }
3622 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3623 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3624 Ok(Box::new(cfg))
3625 }
3626}
3627
3628impl IsqModelLoader for SmolLm3Loader {
3629 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3630 Ok(vec![
3631 Regex::new(r"lm_head\.(weight|bias)$")?,
3632 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3634 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3635 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3636 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3637 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3639 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3640 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3641 ])
3642 }
3643 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3644 self.isq_layer_regexes(config)
3645 }
3646}
3647
3648impl DeviceMappedModelLoader for SmolLm3Loader {
3649 fn mapped_max_act_size_elems(
3650 &self,
3651 config: &str,
3652 params: &AutoDeviceMapParams,
3653 prompt_chunksize: usize,
3654 ) -> Result<usize> {
3655 let AutoDeviceMapParams::Text {
3656 max_seq_len: _,
3657 max_batch_size,
3658 } = params
3659 else {
3660 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3661 };
3662
3663 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3664
3665 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3666 }
3667 fn non_mapped_max_act_size_elems(
3668 &self,
3669 _config: &str,
3670 _params: &AutoDeviceMapParams,
3671 ) -> Result<usize> {
3672 Ok(0)
3673 }
3674
3675 fn non_mapped_size_in_bytes(
3676 &self,
3677 config: &str,
3678 dtype: DType,
3679 weight_pack_factor: usize,
3680 _matformer_config: Option<&MatformerSliceConfig>,
3681 ) -> Result<usize> {
3682 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3683
3684 let elems = {
3685 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3686 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3688 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3689 } else {
3690 0
3691 };
3692 let norm = cfg.hidden_size;
3693 embed_tokens + lm_head + norm
3694 };
3695 Ok(elems * dtype.size_in_bytes())
3696 }
3697
3698 fn layer_sizes_in_bytes(
3699 &self,
3700 config: &str,
3701 dtype: DType,
3702 weight_pack_factor: usize,
3703 _matformer_config: Option<&MatformerSliceConfig>,
3704 ) -> Result<Vec<usize>> {
3705 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3706
3707 let per_layer_elems = {
3708 let input_layernorm = cfg.hidden_size;
3709 let post_attention_layernorm = cfg.hidden_size;
3710
3711 let size_in = cfg.hidden_size;
3712 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3713 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3714 let q_proj = size_in * size_q / weight_pack_factor;
3715 let k_proj = size_in * size_kv / weight_pack_factor;
3716 let v_proj = size_in * size_kv / weight_pack_factor;
3717 let o_proj = size_q * size_in / weight_pack_factor;
3718
3719 let h_size = cfg.hidden_size;
3720 let i_size = cfg.intermediate_size;
3721 let gate_proj = h_size * i_size / weight_pack_factor;
3722 let up_proj = h_size * i_size / weight_pack_factor;
3723 let down_proj = i_size * h_size / weight_pack_factor;
3724
3725 input_layernorm
3726 + post_attention_layernorm
3727 + q_proj
3728 + k_proj
3729 + v_proj
3730 + o_proj
3731 + gate_proj
3732 + up_proj
3733 + down_proj
3734 };
3735 Ok(vec![
3736 per_layer_elems * dtype.size_in_bytes();
3737 cfg.num_hidden_layers
3738 ])
3739 }
3740
3741 fn num_layers(&self, config: &str) -> Result<usize> {
3742 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3743
3744 Ok(cfg.num_hidden_layers)
3745 }
3746 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3747 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3748
3749 let cfg = ModelConfigMetadata {
3750 max_seq_len: cfg.max_position_embeddings,
3751 num_layers: cfg.num_hidden_layers,
3752 hidden_size: cfg.hidden_size,
3753 num_kv_heads: cfg.num_key_value_heads,
3754 num_attn_heads: cfg.num_attention_heads,
3755 sliding_window: None,
3756 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3757 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3758 };
3759
3760 Ok(Box::new(cfg))
3761 }
3762}