use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    lora::{LoraConfig, Ordering},
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{
        isq::IsqModelLoader,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::varbuilder_utils::DeviceForLoadTensor,
    xlora_models::NonGranularState,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use mistralrs_quant::log::once_log_info;

use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use crate::{
    models,
    xlora_models::{self, XLoraConfig},
};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

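/// Common interface implemented by all "normal" (decoder-only text) models so the pipeline
/// can run forward passes, X-LoRA forward passes, and query cache/device/config state.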
pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn is_xlora(&self) -> bool;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
}

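/// Metadata handed to a model at load time: device mapping, ISQ flag, target device,
/// progress reporting, and optional MatFormer slicing.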
pub struct NormalLoadingMetadata {
    // Device mapper used to place each layer on its target device
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Whether the model is being loaded for in-situ quantization (ISQ)
    pub loading_isq: bool,
    // The real (non-CPU) target device
    pub real_device: Device,
    // Progress bars for parallelized loading
    pub multi_progress: Arc<MultiProgress>,
    // Optional MatFormer slicing configuration
    pub matformer_slicing_config: Option<MatformerSliceConfig>,
}

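/// Constructs a concrete [`NormalModel`] (optionally with X-LoRA adapters) from a JSON config
/// string, and exposes per-architecture metadata used for ISQ and device mapping.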
pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    #[allow(clippy::too_many_arguments)]
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

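/// The architecture a "normal" (text) model should be loaded as. Each variant maps to one
/// Hugging Face `*ForCausalLM` class name and one lowercase string identifier.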
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    #[serde(rename = "qwen3")]
    Qwen3,
    #[serde(rename = "glm4")]
    GLM4,
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
    #[serde(rename = "smollm3")]
    SmolLm3,
    #[serde(rename = "granitemoehybrid")]
    GraniteMoeHybrid,
}

impl NormalLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "MistralForCausalLM" => Ok(Self::Mistral),
            "MixtralForCausalLM" => Ok(Self::Mixtral),
            "GemmaForCausalLM" => Ok(Self::Gemma),
            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
            "PhiForCausalLM" => Ok(Self::Phi2),
            "Phi3ForCausalLM" => Ok(Self::Phi3),
            "LlamaForCausalLM" => Ok(Self::Llama),
            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
            "Glm4ForCausalLM" => Ok(Self::GLM4),
            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
            "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
            ),
        }
    }
}

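/// Parses the lowercase architecture identifiers listed below (the same strings used by the
/// `#[serde(rename = ...)]` attributes), e.g.:
///
/// ```ignore
/// use std::str::FromStr;
/// let arch = NormalLoaderType::from_str("mistral").unwrap();
/// assert_eq!(arch, NormalLoaderType::Mistral);
/// ```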
impl FromStr for NormalLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "mistral" => Ok(Self::Mistral),
            "gemma" => Ok(Self::Gemma),
            "mixtral" => Ok(Self::Mixtral),
            "llama" => Ok(Self::Llama),
            "phi2" => Ok(Self::Phi2),
            "phi3" => Ok(Self::Phi3),
            "qwen2" => Ok(Self::Qwen2),
            "gemma2" => Ok(Self::Gemma2),
            "starcoder2" => Ok(Self::Starcoder2),
            "phi3.5moe" => Ok(Self::Phi3_5MoE),
            "deepseekv2" => Ok(Self::DeepSeekV2),
            "deepseekv3" => Ok(Self::DeepSeekV3),
            "qwen3" => Ok(Self::Qwen3),
            "glm4" => Ok(Self::GLM4),
            "qwen3moe" => Ok(Self::Qwen3Moe),
            "smollm3" => Ok(Self::SmolLm3),
            "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`, `granitemoehybrid`.")),
        }
    }
}

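/// Formats the loader type as the same string identifier accepted by `from_str`.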
impl Display for NormalLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gemma => write!(f, "gemma"),
            Self::Gemma2 => write!(f, "gemma2"),
            Self::Llama => write!(f, "llama"),
            Self::Mistral => write!(f, "mistral"),
            Self::Mixtral => write!(f, "mixtral"),
            Self::Phi2 => write!(f, "phi2"),
            Self::Phi3 => write!(f, "phi3"),
            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
            Self::Qwen2 => write!(f, "qwen2"),
            Self::Starcoder2 => write!(f, "starcoder2"),
            Self::DeepSeekV2 => write!(f, "deepseekv2"),
            Self::DeepSeekV3 => write!(f, "deepseekv3"),
            Self::Qwen3 => write!(f, "qwen3"),
            Self::GLM4 => write!(f, "glm4"),
            Self::Qwen3Moe => write!(f, "qwen3moe"),
            Self::SmolLm3 => write!(f, "smollm3"),
            Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
        }
    }
}

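// Counts `$size` bias elements only when `$cond` is true; used by the per-layer size
// estimates below.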
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

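/// Loader that reads the `architectures` field of the model's JSON config and dispatches to
/// the matching architecture-specific loader (e.g. `"MistralForCausalLM"` resolves to
/// [`MistralLoader`]).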
pub struct AutoNormalLoader;

#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    architectures: Vec<String>,
}

impl AutoNormalLoader {
    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected to have one name for `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = NormalLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
            NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
        }
    }
}

impl NormalModelLoader for AutoNormalLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load_xlora(
            config,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            normal_loading_metadata,
            preload_adapters,
        )
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }
    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.supports_paged_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoNormalLoader {
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
    }
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for AutoNormalLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

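/// [`NormalModelLoader`] implementation for Mistral-architecture models.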
pub struct MistralLoader;

impl NormalModelLoader for MistralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(models::mistral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(xlora_models::XLoraMistral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MistralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MistralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

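/// [`NormalModelLoader`] implementation for Gemma models.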
pub struct GemmaLoader;

impl NormalModelLoader for GemmaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::gemma::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraGemma::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for GemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }
}

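/// [`NormalModelLoader`] implementation for Llama models.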
pub struct LlamaLoader;

impl NormalModelLoader for LlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::llama::Llama::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraLlama::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for LlamaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for LlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

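/// [`NormalModelLoader`] implementation for Mixtral (sparse MoE) models.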
pub struct MixtralLoader;

impl NormalModelLoader for MixtralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::mixtral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraMixtral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MixtralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MixtralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let moe_block = {
                let gate = cfg.hidden_size * cfg.num_local_experts;
                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                gate + cfg.num_local_experts * w1
                    + cfg.num_local_experts * w2
                    + cfg.num_local_experts * w3
            };

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + moe_block
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

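/// [`NormalModelLoader`] implementation for Phi-2 models.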
pub struct Phi2Loader;

impl NormalModelLoader for Phi2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::phi2::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraPhi2::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Phi2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor + size_in;
            let (q_norm, k_norm) = if cfg.qk_layernorm {
                (cfg.head_dim(), cfg.head_dim())
            } else {
                (0, 0)
            };

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let fc1 = h_size * i_size / weight_pack_factor;
            let fc2 = h_size * i_size / weight_pack_factor;

            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

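/// [`NormalModelLoader`] implementation for Phi-3 models.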
pub struct Phi3Loader;

impl NormalModelLoader for Phi3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::phi3::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraPhi3::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Phi3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj =
                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

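/// [`NormalModelLoader`] implementation for Qwen2 models. X-LoRA loading is not yet
/// supported for this architecture (`load_xlora` is `todo!()`).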
pub struct Qwen2Loader;

impl NormalModelLoader for Qwen2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::qwen2::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Qwen2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

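/// [`NormalModelLoader`] implementation for Gemma2 models.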
1777pub struct Gemma2Loader;
1783
1784impl NormalModelLoader for Gemma2Loader {
1785 fn load(
1786 &self,
1787 config: &str,
1788 vb: ShardedVarBuilder,
1789 normal_loading_metadata: NormalLoadingMetadata,
1790 attention_mechanism: AttentionImplementation,
1791 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1792 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1793
1794 Ok(Box::new(models::gemma2::Model::new(
1795 &cfg,
1796 vb,
1797 self.is_gptx(config)?,
1798 normal_loading_metadata,
1799 attention_mechanism,
1800 )?))
1801 }
1802 fn load_xlora(
1803 &self,
1804 config: &str,
1805 vb: ShardedVarBuilder,
1806 lora_config: &[((String, String), LoraConfig)],
1807 xlora_config: Option<XLoraConfig>,
1808 xlora_ordering: Ordering,
1809 normal_loading_metadata: NormalLoadingMetadata,
1810 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1811 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1812 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1813
1814 Ok(Box::new(xlora_models::XLoraGemma2::new(
1815 &cfg,
1816 vb,
1817 lora_config,
1818 xlora_config,
1819 xlora_ordering,
1820 self.is_gptx(config)?,
1821 normal_loading_metadata,
1822 preload_adapters,
1823 )?))
1824 }
1825 fn is_gptx(&self, _: &str) -> Result<bool> {
1826 Ok(true)
1827 }
1828 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1829 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1830
1831 Ok(Box::new(cfg))
1832 }
1833}
1834
1835impl IsqModelLoader for Gemma2Loader {
1836 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1837 Ok(vec![
1838 Regex::new(r"lm_head\.(weight|bias)$")?,
1839 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1841 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1842 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1843 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1844 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1846 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1847 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1848 ])
1849 }
1850 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1851 self.isq_layer_regexes(config)
1852 }
1853}
1854
1855impl DeviceMappedModelLoader for Gemma2Loader {
1856 fn mapped_max_act_size_elems(
1857 &self,
1858 config: &str,
1859 params: &AutoDeviceMapParams,
1860 ) -> Result<usize> {
1861 let AutoDeviceMapParams::Text {
1862 max_seq_len,
1863 max_batch_size,
1864 } = params
1865 else {
1866 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1867 };
1868
1869 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1870
1871 Ok(
1872 max_batch_size
1873 * cfg.num_attention_heads
1874 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1875 )
1876 }
1877 fn non_mapped_max_act_size_elems(
1878 &self,
1879 _config: &str,
1880 _params: &AutoDeviceMapParams,
1881 ) -> Result<usize> {
1882 Ok(0)
1883 }
1884
1885 fn non_mapped_size_in_bytes(
1886 &self,
1887 config: &str,
1888 dtype: DType,
1889 weight_pack_factor: usize,
1890 _matformer_config: Option<&MatformerSliceConfig>,
1891 ) -> Result<usize> {
1892 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1893
1894 let elems = {
1895 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1896 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1898 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1899 } else {
1900 0
1901 };
1902 let norm = cfg.hidden_size;
1903 embed_tokens + lm_head + norm
1904 };
1905 Ok(elems * dtype.size_in_bytes())
1906 }
1907
1908 fn layer_sizes_in_bytes(
1909 &self,
1910 config: &str,
1911 dtype: DType,
1912 weight_pack_factor: usize,
1913 _matformer_config: Option<&MatformerSliceConfig>,
1914 ) -> Result<Vec<usize>> {
1915 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1916
1917 let per_layer_elems = {
1918 let input_layernorm = cfg.hidden_size;
1919 let post_attention_layernorm = cfg.hidden_size;
1920
1921 let size_in = cfg.hidden_size;
1922 let size_q = cfg.head_dim * cfg.num_attention_heads;
1923 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1924 let q_proj =
1925 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1926 let k_proj =
1927 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1928 let v_proj =
1929 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1930 let o_proj =
1931 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1932
1933 let h_size = cfg.hidden_size;
1934 let i_size = cfg.intermediate_size;
1935 let gate_proj = h_size * i_size / weight_pack_factor;
1936 let up_proj = h_size * i_size / weight_pack_factor;
1937 let down_proj = i_size * h_size / weight_pack_factor;
1938
1939 input_layernorm
1940 + post_attention_layernorm
1941 + q_proj
1942 + k_proj
1943 + v_proj
1944 + o_proj
1945 + gate_proj
1946 + up_proj
1947 + down_proj
1948 };
1949 Ok(vec![
1950 per_layer_elems * dtype.size_in_bytes();
1951 cfg.num_hidden_layers
1952 ])
1953 }
1954
1955 fn num_layers(&self, config: &str) -> Result<usize> {
1956 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1957
1958 Ok(cfg.num_hidden_layers)
1959 }
1960 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1961 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1962
1963 let cfg = ModelConfigMetadata {
1964 max_seq_len: cfg.max_position_embeddings,
1965 num_layers: cfg.num_hidden_layers,
1966 hidden_size: cfg.hidden_size,
1967 num_kv_heads: cfg.num_key_value_heads,
1968 num_attn_heads: cfg.num_attention_heads,
1969 sliding_window: None,
1970 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1971 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1972 };
1973
1974 Ok(Box::new(cfg))
1975 }
1976}
1977
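/// [`NormalModelLoader`] for Starcoder2 models, backed by `crate::models::starcoder2::Config`
/// and `models::starcoder2::Model` / `xlora_models::XLoraStarcoder2`.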
1978pub struct Starcoder2Loader;
1984
1985impl NormalModelLoader for Starcoder2Loader {
1986 fn load(
1987 &self,
1988 config: &str,
1989 vb: ShardedVarBuilder,
1990 normal_loading_metadata: NormalLoadingMetadata,
1991 attention_mechanism: AttentionImplementation,
1992 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1993 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1994
1995 Ok(Box::new(models::starcoder2::Model::new(
1996 &cfg,
1997 vb,
1998 self.is_gptx(config)?,
1999 normal_loading_metadata,
2000 attention_mechanism,
2001 )?))
2002 }
2003 fn load_xlora(
2004 &self,
2005 config: &str,
2006 vb: ShardedVarBuilder,
2007 lora_config: &[((String, String), LoraConfig)],
2008 xlora_config: Option<XLoraConfig>,
2009 xlora_ordering: Ordering,
2010 normal_loading_metadata: NormalLoadingMetadata,
2011 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2012 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2013 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2014
2015 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2016 &cfg,
2017 vb,
2018 lora_config,
2019 xlora_config,
2020 xlora_ordering,
2021 self.is_gptx(config)?,
2022 normal_loading_metadata,
2023 preload_adapters,
2024 )?))
2025 }
2026 fn is_gptx(&self, _: &str) -> Result<bool> {
2027 Ok(true)
2028 }
2029 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2030 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2031
2032 Ok(Box::new(cfg))
2033 }
2034}
2035
2036impl IsqModelLoader for Starcoder2Loader {
2037 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2038 Ok(vec![
2039 Regex::new(r"lm_head\.(weight|bias)$")?,
2040 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2042 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2043 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2044 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2045 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2047 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2048 ])
2049 }
2050 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2051 self.isq_layer_regexes(config)
2052 }
2053}
2054
2055impl DeviceMappedModelLoader for Starcoder2Loader {
2056 fn mapped_max_act_size_elems(
2057 &self,
2058 config: &str,
2059 params: &AutoDeviceMapParams,
2060 ) -> Result<usize> {
2061 let AutoDeviceMapParams::Text {
2062 max_seq_len,
2063 max_batch_size,
2064 } = params
2065 else {
2066 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2067 };
2068
2069 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2070
2071 Ok(
2072 max_batch_size
2073 * cfg.num_attention_heads
2074 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2075 )
2076 }
2077 fn non_mapped_max_act_size_elems(
2078 &self,
2079 _config: &str,
2080 _params: &AutoDeviceMapParams,
2081 ) -> Result<usize> {
2082 Ok(0)
2083 }
2084
2085 fn non_mapped_size_in_bytes(
2086 &self,
2087 config: &str,
2088 dtype: DType,
2089 weight_pack_factor: usize,
2090 _matformer_config: Option<&MatformerSliceConfig>,
2091 ) -> Result<usize> {
2092 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2093
2094 let elems = {
2095 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2096 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2098 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2099 } else {
2100 0
2101 };
2102 let norm = cfg.hidden_size + cfg.hidden_size;
2103 embed_tokens + lm_head + norm
2104 };
2105 Ok(elems * dtype.size_in_bytes())
2106 }
2107
2108 fn layer_sizes_in_bytes(
2109 &self,
2110 config: &str,
2111 dtype: DType,
2112 weight_pack_factor: usize,
2113 _matformer_config: Option<&MatformerSliceConfig>,
2114 ) -> Result<Vec<usize>> {
2115 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2116
2117 let per_layer_elems = {
2118 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2119 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2120
2121 let size_in = cfg.hidden_size;
2122 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2123 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2124 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2125 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2126 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2127 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2128
2129 let h_size = cfg.hidden_size;
2130 let i_size = cfg.intermediate_size;
2131 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2132 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2133
2134 input_layernorm
2135 + post_attention_layernorm
2136 + q_proj
2137 + k_proj
2138 + v_proj
2139 + o_proj
2140 + fc1
2141 + fc2
2142 };
2143 Ok(vec![
2144 per_layer_elems * dtype.size_in_bytes();
2145 cfg.num_hidden_layers
2146 ])
2147 }
2148
2149 fn num_layers(&self, config: &str) -> Result<usize> {
2150 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2151
2152 Ok(cfg.num_hidden_layers)
2153 }
2154
2155 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2156 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2157
2158 let cfg = ModelConfigMetadata {
2159 max_seq_len: cfg.max_position_embeddings,
2160 num_layers: cfg.num_hidden_layers,
2161 hidden_size: cfg.hidden_size,
2162 num_kv_heads: cfg.num_key_value_heads,
2163 num_attn_heads: cfg.num_attention_heads,
2164 sliding_window: cfg.sliding_window,
2165 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2166 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2167 };
2168
2169 Ok(Box::new(cfg))
2170 }
2171}
2172
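/// [`NormalModelLoader`] for Phi-3.5-MoE models, backed by `crate::models::phi3_5_moe::Config`.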
2173pub struct Phi3_5MoELoader;
2179
2180impl NormalModelLoader for Phi3_5MoELoader {
2181 fn load(
2182 &self,
2183 config: &str,
2184 vb: ShardedVarBuilder,
2185 normal_loading_metadata: NormalLoadingMetadata,
2186 attention_mechanism: AttentionImplementation,
2187 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2188 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2189
2190 Ok(Box::new(models::phi3_5_moe::Model::new(
2191 &cfg,
2192 vb,
2193 self.is_gptx(config)?,
2194 normal_loading_metadata,
2195 attention_mechanism,
2196 )?))
2197 }
2198 fn load_xlora(
2199 &self,
2200 config: &str,
2201 vb: ShardedVarBuilder,
2202 lora_config: &[((String, String), LoraConfig)],
2203 xlora_config: Option<XLoraConfig>,
2204 xlora_ordering: Ordering,
2205 normal_loading_metadata: NormalLoadingMetadata,
2206 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2207 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2208 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2209
2210 Ok(Box::new(xlora_models::XLoraPhi3::new(
2211 &cfg,
2212 vb,
2213 lora_config,
2214 xlora_config,
2215 xlora_ordering,
2216 self.is_gptx(config)?,
2217 normal_loading_metadata,
2218 preload_adapters,
2219 )?))
2220 }
2221 fn is_gptx(&self, _: &str) -> Result<bool> {
2222 Ok(true)
2223 }
2224 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2225 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2226
2227 Ok(Box::new(cfg))
2228 }
2229}
2230
2231impl IsqModelLoader for Phi3_5MoELoader {
2232 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2233 Ok(vec![
2234 Regex::new(r"lm_head\.(weight|bias)$")?,
2235 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2237 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2238 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2239 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2240 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2242 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2243 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2244 ])
2245 }
2246 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2247 self.isq_layer_regexes(config)
2248 }
2249
2250 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2251 Ok(vec![
2252 Regex::new(r"lm_head\.(weight|bias)$")?,
2253 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2255 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2256 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2257 ])
2258 }
2259 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2260 self.isq_layer_regexes_moqe(config)
2261 }
2262}
2263
2264impl DeviceMappedModelLoader for Phi3_5MoELoader {
2265 fn mapped_max_act_size_elems(
2266 &self,
2267 config: &str,
2268 params: &AutoDeviceMapParams,
2269 ) -> Result<usize> {
2270 let AutoDeviceMapParams::Text {
2271 max_seq_len,
2272 max_batch_size,
2273 } = params
2274 else {
2275 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2276 };
2277
2278 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2279
2280 Ok(
2281 max_batch_size
2282 * cfg.num_attention_heads
2283 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2284 )
2285 }
2286 fn non_mapped_max_act_size_elems(
2287 &self,
2288 _config: &str,
2289 _params: &AutoDeviceMapParams,
2290 ) -> Result<usize> {
2291 Ok(0)
2292 }
2293
2294 fn non_mapped_size_in_bytes(
2295 &self,
2296 config: &str,
2297 dtype: DType,
2298 weight_pack_factor: usize,
2299 _matformer_config: Option<&MatformerSliceConfig>,
2300 ) -> Result<usize> {
2301 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2302
2303 let elems = {
2304 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2305 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2307 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2308 } else {
2309 0
2310 };
2311 let norm = cfg.hidden_size;
2312 embed_tokens + lm_head + norm
2313 };
2314 Ok(elems * dtype.size_in_bytes())
2315 }
2316
2317 fn layer_sizes_in_bytes(
2318 &self,
2319 config: &str,
2320 dtype: DType,
2321 weight_pack_factor: usize,
2322 _matformer_config: Option<&MatformerSliceConfig>,
2323 ) -> Result<Vec<usize>> {
2324 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2325
2326 let per_layer_elems = {
2327 let input_layernorm = cfg.hidden_size;
2328 let post_attention_layernorm = cfg.hidden_size;
2329
2330 let size_in = cfg.hidden_size;
2331 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2332 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2333 let q_proj =
2334 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2335 let k_proj =
2336 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2337 let v_proj =
2338 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2339 let o_proj =
2340 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2341
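// Sparse MoE block: a routing gate plus `num_local_experts` copies of w1/w2/w3.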
2342 let moe_block = {
2343 let gate = cfg.hidden_size * cfg.num_local_experts;
2344 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2346 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2347 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2348 gate + cfg.num_local_experts * w1
2349 + cfg.num_local_experts * w2
2350 + cfg.num_local_experts * w3
2351 };
2352
2353 input_layernorm
2354 + post_attention_layernorm
2355 + q_proj
2356 + k_proj
2357 + v_proj
2358 + o_proj
2359 + moe_block
2360 };
2361 Ok(vec![
2362 per_layer_elems * dtype.size_in_bytes();
2363 cfg.num_hidden_layers
2364 ])
2365 }
2366
2367 fn num_layers(&self, config: &str) -> Result<usize> {
2368 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2369
2370 Ok(cfg.num_hidden_layers)
2371 }
2372
2373 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2374 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2375
2376 let cfg = ModelConfigMetadata {
2377 max_seq_len: cfg.max_position_embeddings,
2378 num_layers: cfg.num_hidden_layers,
2379 hidden_size: cfg.hidden_size,
2380 num_kv_heads: cfg.num_key_value_heads,
2381 num_attn_heads: cfg.num_attention_heads,
2382 sliding_window: cfg.sliding_window,
2383 k_head_dim: cfg.head_dim(),
2384 v_head_dim: cfg.head_dim(),
2385 };
2386
2387 Ok(Box::new(cfg))
2388 }
2389}
2390
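/// [`NormalModelLoader`] for DeepSeek-V2 models, backed by
/// `crate::models::deepseek2::DeepSeekV2Config` (X-LoRA is not implemented).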
2391pub struct DeepSeekV2Loader;
2395
2396impl NormalModelLoader for DeepSeekV2Loader {
2397 fn load(
2398 &self,
2399 config: &str,
2400 vb: ShardedVarBuilder,
2401 normal_loading_metadata: NormalLoadingMetadata,
2402 attention_mechanism: AttentionImplementation,
2403 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2404 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2405
2406 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2407 &cfg,
2408 vb,
2409 self.is_gptx(config)?,
2410 normal_loading_metadata,
2411 attention_mechanism,
2412 )?))
2413 }
2414 fn load_xlora(
2415 &self,
2416 _config: &str,
2417 _vb: ShardedVarBuilder,
2418 _lora_config: &[((String, String), LoraConfig)],
2419 _xlora_config: Option<XLoraConfig>,
2420 _xlora_ordering: Ordering,
2421 _normal_loading_metadata: NormalLoadingMetadata,
2422 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2423 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2424 todo!()
2425 }
2426 fn is_gptx(&self, _: &str) -> Result<bool> {
2427 Ok(true)
2428 }
2429 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2430 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2431 Ok(Box::new(cfg))
2432 }
2433}
2434
2435impl IsqModelLoader for DeepSeekV2Loader {
2436 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2437 let mut data = vec![
2438 Regex::new(r"lm_head\.(weight|bias)$")?,
2439 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2441 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2442 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2443 ];
2444 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2445 if cfg.q_lora_rank.is_some() {
2446 data.extend(vec![
2447 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2448 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2449 ]);
2450 } else {
2451 data.push(Regex::new(
2452 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2453 )?);
2454 }
2455 for layer_idx in 0..cfg.num_hidden_layers {
2456 if cfg.n_routed_experts.is_some()
2457 && layer_idx >= cfg.first_k_dense_replace
2458 && layer_idx % cfg.moe_layer_freq == 0
2459 {
2460 for i in 0..cfg.n_routed_experts.unwrap() {
2461 data.extend(vec![
2462 Regex::new(&format!(
2463 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2464 ))?,
2465 Regex::new(&format!(
2466 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2467 ))?,
2468 Regex::new(&format!(
2469 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2470 ))?,
2471 ]);
2472 }
2473 if cfg.n_shared_experts.is_some() {
2474 data.extend(vec![
2475 Regex::new(&format!(
2476 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2477 ))?,
2478 Regex::new(&format!(
2479 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2480 ))?,
2481 Regex::new(&format!(
2482 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2483 ))?,
2484 ]);
2485 }
2486 } else {
2487 data.extend(vec![
2488 Regex::new(&format!(
2489 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2490 ))?,
2491 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2492 Regex::new(&format!(
2493 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2494 ))?,
2495 ]);
2496 };
2497 }
2498 Ok(data)
2499 }
2500 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2501 self.isq_layer_regexes(config)
2502 }
2503
2504 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2505 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2506 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2507 for layer_idx in 0..cfg.num_hidden_layers {
2508 if cfg.n_routed_experts.is_some()
2509 && layer_idx >= cfg.first_k_dense_replace
2510 && layer_idx % cfg.moe_layer_freq == 0
2511 {
2512 for i in 0..cfg.n_routed_experts.unwrap() {
2513 data.extend(vec![
2514 Regex::new(&format!(
2515 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2516 ))?,
2517 Regex::new(&format!(
2518 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2519 ))?,
2520 Regex::new(&format!(
2521 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2522 ))?,
2523 ]);
2524 }
2525 if cfg.n_shared_experts.is_some() {
2526 data.extend(vec![
2527 Regex::new(&format!(
2528 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2529 ))?,
2530 Regex::new(&format!(
2531 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2532 ))?,
2533 Regex::new(&format!(
2534 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2535 ))?,
2536 ]);
2537 }
2538 } else {
2539 data.extend(vec![
2540 Regex::new(&format!(
2541 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2542 ))?,
2543 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2544 Regex::new(&format!(
2545 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2546 ))?,
2547 ]);
2548 };
2549 }
2550 Ok(data)
2551 }
2552 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2553 self.isq_layer_regexes_moqe(config)
2554 }
2555}
2556
2557impl DeviceMappedModelLoader for DeepSeekV2Loader {
2558 fn mapped_max_act_size_elems(
2559 &self,
2560 config: &str,
2561 params: &AutoDeviceMapParams,
2562 ) -> Result<usize> {
2563 let AutoDeviceMapParams::Text {
2564 max_seq_len,
2565 max_batch_size,
2566 } = params
2567 else {
2568 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2569 };
2570
2571 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2572
2573 Ok(
2574 max_batch_size
2575 * cfg.num_attention_heads
2576 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2577 )
2578 }
2579 fn non_mapped_max_act_size_elems(
2580 &self,
2581 _config: &str,
2582 _params: &AutoDeviceMapParams,
2583 ) -> Result<usize> {
2584 Ok(0)
2585 }
2586
2587 fn non_mapped_size_in_bytes(
2588 &self,
2589 config: &str,
2590 dtype: DType,
2591 weight_pack_factor: usize,
2592 _matformer_config: Option<&MatformerSliceConfig>,
2593 ) -> Result<usize> {
2594 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2595 let elems = {
2596 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2597 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2599 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2600 } else {
2601 0
2602 };
2603 let norm = cfg.hidden_size;
2604 embed_tokens + lm_head + norm
2605 };
2606 Ok(elems * dtype.size_in_bytes())
2607 }
2608
2609 fn layer_sizes_in_bytes(
2610 &self,
2611 config: &str,
2612 dtype: DType,
2613 weight_pack_factor: usize,
2614 _matformer_config: Option<&MatformerSliceConfig>,
2615 ) -> Result<Vec<usize>> {
2616 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2617 let mut per_layer_elems = Vec::new();
2618
2619 for layer_idx in 0..cfg.num_hidden_layers {
2620 let input_layernorm = cfg.hidden_size;
2621 let post_attention_layernorm = cfg.hidden_size;
2622
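// MLA attention: when q_lora_rank is set the query path is low-rank factored
// (q_a_proj + q_a_layernorm + q_b_proj); keys/values always go through the
// compressed kv_a_proj_with_mqa / kv_b_proj pair.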
2623 let q_proj = match cfg.q_lora_rank {
2624 Some(lora_rank) => {
2625 let a = cfg.hidden_size * lora_rank;
2626 let norm = lora_rank;
2627 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2628 a + norm + b
2629 }
2630 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2631 };
2632 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2633 / weight_pack_factor
2634 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2635 let kv_a_layernorm = cfg.kv_lora_rank;
2636 let kv_b_proj = cfg.kv_lora_rank
2637 * cfg.num_attention_heads
2638 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2639 / weight_pack_factor;
2640 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2641 / weight_pack_factor
2642 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2643
2644 let moe_block = {
2645 let mut sum = 0;
2646 if cfg.n_routed_experts.is_some()
2647 && layer_idx >= cfg.first_k_dense_replace
2648 && layer_idx % cfg.moe_layer_freq == 0
2649 {
2650 let h_size = cfg.hidden_size;
2651 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2652 * cfg.n_routed_experts.unwrap();
2653 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2654 * cfg.n_routed_experts.unwrap();
2655 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2656 * cfg.n_routed_experts.unwrap();
2657 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2658 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2659 / weight_pack_factor;
2660 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2661 / weight_pack_factor;
2662 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2663 / weight_pack_factor;
2664 gate_proj + up_proj + down_proj
2665 } else {
2666 0
2667 };
2668 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2669 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2670 } else {
2671 let h_size = cfg.hidden_size;
2672 let i_size = cfg.intermediate_size;
2673 let gate_proj = h_size * i_size / weight_pack_factor;
2674 let up_proj = h_size * i_size / weight_pack_factor;
2675 let down_proj = i_size * h_size / weight_pack_factor;
2676 sum += gate_proj + up_proj + down_proj;
2677 }
2678 sum
2679 };
2680
2681 per_layer_elems.push(
2682 input_layernorm
2683 + post_attention_layernorm
2684 + q_proj
2685 + kv_a_layernorm
2686 + kv_a_proj_with_mqa
2687 + kv_b_proj
2688 + o_proj
2689 + moe_block,
2690 );
2691 }
2692
2693 Ok(per_layer_elems
2694 .into_iter()
2695 .map(|x| x * dtype.size_in_bytes())
2696 .collect())
2697 }
2698
2699 fn num_layers(&self, config: &str) -> Result<usize> {
2700 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2701 Ok(cfg.num_hidden_layers)
2702 }
2703
2704 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2705 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2706
2707 let cfg = ModelConfigMetadata {
2708 max_seq_len: cfg.max_position_embeddings,
2709 num_layers: cfg.num_hidden_layers,
2710 hidden_size: cfg.hidden_size,
2711 num_kv_heads: cfg.num_attention_heads,
2712 num_attn_heads: cfg.num_attention_heads,
2713 sliding_window: None,
2714 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2715 v_head_dim: cfg.v_head_dim,
2716 };
2717
2718 Ok(Box::new(cfg))
2719 }
2720}
2721
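/// [`NormalModelLoader`] for DeepSeek-V3 models, backed by
/// `crate::models::deepseek3::DeepSeekV3Config` (X-LoRA is not implemented).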
2722pub struct DeepSeekV3Loader;
2726
2727impl NormalModelLoader for DeepSeekV3Loader {
2728 fn load(
2729 &self,
2730 config: &str,
2731 vb: ShardedVarBuilder,
2732 normal_loading_metadata: NormalLoadingMetadata,
2733 attention_mechanism: AttentionImplementation,
2734 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2735 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2736 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2737 &cfg,
2738 vb,
2739 self.is_gptx(config)?,
2740 normal_loading_metadata,
2741 attention_mechanism,
2742 )?))
2743 }
2744 fn load_xlora(
2745 &self,
2746 _config: &str,
2747 _vb: ShardedVarBuilder,
2748 _lora_config: &[((String, String), LoraConfig)],
2749 _xlora_config: Option<XLoraConfig>,
2750 _xlora_ordering: Ordering,
2751 _normal_loading_metadata: NormalLoadingMetadata,
2752 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2753 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2754 todo!()
2755 }
2756 fn is_gptx(&self, _: &str) -> Result<bool> {
2757 Ok(true)
2758 }
2759 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2760 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2761 Ok(Box::new(cfg))
2762 }
2763}
2764
2765impl IsqModelLoader for DeepSeekV3Loader {
2766 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2767 let mut data = vec![
2768 Regex::new(r"lm_head\.(weight|bias)$")?,
2769 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2771 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2772 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2773 ];
2774 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2775 if cfg.q_lora_rank.is_some() {
2776 data.extend(vec![
2777 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2778 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2779 ]);
2780 } else {
2781 data.push(Regex::new(
2782 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2783 )?);
2784 }
2785 for layer_idx in 0..cfg.num_hidden_layers {
2786 if cfg.n_routed_experts.is_some()
2787 && layer_idx >= cfg.first_k_dense_replace
2788 && layer_idx % cfg.moe_layer_freq == 0
2789 {
2790 for i in 0..cfg.n_routed_experts.unwrap() {
2791 data.extend(vec![
2792 Regex::new(&format!(
2793 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2794 ))?,
2795 Regex::new(&format!(
2796 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2797 ))?,
2798 Regex::new(&format!(
2799 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2800 ))?,
2801 ]);
2802 }
2803 if cfg.n_shared_experts.is_some() {
2804 data.extend(vec![
2805 Regex::new(&format!(
2806 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2807 ))?,
2808 Regex::new(&format!(
2809 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2810 ))?,
2811 Regex::new(&format!(
2812 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2813 ))?,
2814 ]);
2815 }
2816 } else {
2817 data.extend(vec![
2818 Regex::new(&format!(
2819 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2820 ))?,
2821 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2822 Regex::new(&format!(
2823 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2824 ))?,
2825 ]);
2826 };
2827 }
2828 Ok(data)
2829 }
2830 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2831 self.isq_layer_regexes(config)
2832 }
2833
2834 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2835 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2836 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2837 for layer_idx in 0..cfg.num_hidden_layers {
2838 if cfg.n_routed_experts.is_some()
2839 && layer_idx >= cfg.first_k_dense_replace
2840 && layer_idx % cfg.moe_layer_freq == 0
2841 {
2842 for i in 0..cfg.n_routed_experts.unwrap() {
2843 data.extend(vec![
2844 Regex::new(&format!(
2845 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2846 ))?,
2847 Regex::new(&format!(
2848 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2849 ))?,
2850 Regex::new(&format!(
2851 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2852 ))?,
2853 ]);
2854 }
2855 if cfg.n_shared_experts.is_some() {
2856 data.extend(vec![
2857 Regex::new(&format!(
2858 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2859 ))?,
2860 Regex::new(&format!(
2861 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2862 ))?,
2863 Regex::new(&format!(
2864 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2865 ))?,
2866 ]);
2867 }
2868 } else {
2869 data.extend(vec![
2870 Regex::new(&format!(
2871 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2872 ))?,
2873 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2874 Regex::new(&format!(
2875 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2876 ))?,
2877 ]);
2878 };
2879 }
2880 Ok(data)
2881 }
2882 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2883 self.isq_layer_regexes_moqe(config)
2884 }
2885}
2886
2887impl DeviceMappedModelLoader for DeepSeekV3Loader {
2888 fn mapped_max_act_size_elems(
2889 &self,
2890 config: &str,
2891 params: &AutoDeviceMapParams,
2892 ) -> Result<usize> {
2893 let AutoDeviceMapParams::Text {
2894 max_seq_len,
2895 max_batch_size,
2896 } = params
2897 else {
2898 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2899 };
2900
2901 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2902
2903 Ok(
2904 max_batch_size
2905 * cfg.num_attention_heads
2906 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2907 )
2908 }
2909 fn non_mapped_max_act_size_elems(
2910 &self,
2911 _config: &str,
2912 _params: &AutoDeviceMapParams,
2913 ) -> Result<usize> {
2914 Ok(0)
2915 }
2916
2917 fn non_mapped_size_in_bytes(
2918 &self,
2919 config: &str,
2920 dtype: DType,
2921 weight_pack_factor: usize,
2922 _matformer_config: Option<&MatformerSliceConfig>,
2923 ) -> Result<usize> {
2924 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2925 let elems = {
2926 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2927 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2929 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2930 } else {
2931 0
2932 };
2933 let norm = cfg.hidden_size;
2934 embed_tokens + lm_head + norm
2935 };
2936 Ok(elems * dtype.size_in_bytes())
2937 }
2938
2939 fn layer_sizes_in_bytes(
2940 &self,
2941 config: &str,
2942 dtype: DType,
2943 weight_pack_factor: usize,
2944 _matformer_config: Option<&MatformerSliceConfig>,
2945 ) -> Result<Vec<usize>> {
2946 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2947 let mut per_layer_elems = Vec::new();
2948
2949 for layer_idx in 0..cfg.num_hidden_layers {
2950 let input_layernorm = cfg.hidden_size;
2951 let post_attention_layernorm = cfg.hidden_size;
2952
2953 let q_proj = match cfg.q_lora_rank {
2954 Some(lora_rank) => {
2955 let a = cfg.hidden_size * lora_rank;
2956 let norm = lora_rank;
2957 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2958 a + norm + b
2959 }
2960 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2961 };
2962 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2963 / weight_pack_factor
2964 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2965 let kv_a_layernorm = cfg.kv_lora_rank;
2966 let kv_b_proj = cfg.kv_lora_rank
2967 * cfg.num_attention_heads
2968 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2969 / weight_pack_factor;
2970 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2971 / weight_pack_factor
2972 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2973
2974 let moe_block = {
2975 let mut sum = 0;
2976 if cfg.n_routed_experts.is_some()
2977 && layer_idx >= cfg.first_k_dense_replace
2978 && layer_idx % cfg.moe_layer_freq == 0
2979 {
2980 let h_size = cfg.hidden_size;
2981 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2982 * cfg.n_routed_experts.unwrap();
2983 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2984 * cfg.n_routed_experts.unwrap();
2985 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2986 * cfg.n_routed_experts.unwrap();
2987 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2988 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2989 / weight_pack_factor;
2990 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2991 / weight_pack_factor;
2992 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2993 / weight_pack_factor;
2994 gate_proj + up_proj + down_proj
2995 } else {
2996 0
2997 };
2998 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2999 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3000 } else {
3001 let h_size = cfg.hidden_size;
3002 let i_size = cfg.intermediate_size;
3003 let gate_proj = h_size * i_size / weight_pack_factor;
3004 let up_proj = h_size * i_size / weight_pack_factor;
3005 let down_proj = i_size * h_size / weight_pack_factor;
3006 sum += gate_proj + up_proj + down_proj;
3007 }
3008 sum
3009 };
3010
3011 per_layer_elems.push(
3012 input_layernorm
3013 + post_attention_layernorm
3014 + q_proj
3015 + kv_a_layernorm
3016 + kv_a_proj_with_mqa
3017 + kv_b_proj
3018 + o_proj
3019 + moe_block,
3020 );
3021 }
3022
3023 Ok(per_layer_elems
3024 .into_iter()
3025 .map(|x| x * dtype.size_in_bytes())
3026 .collect())
3027 }
3028
3029 fn num_layers(&self, config: &str) -> Result<usize> {
3030 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3031 Ok(cfg.num_hidden_layers)
3032 }
3033
3034 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3035 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3036
3037 let cfg = ModelConfigMetadata {
3038 max_seq_len: cfg.max_position_embeddings,
3039 num_layers: cfg.num_hidden_layers,
3040 hidden_size: cfg.hidden_size,
3041 num_kv_heads: cfg.num_attention_heads,
3042 num_attn_heads: cfg.num_attention_heads,
3043 sliding_window: None,
3044 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3045 v_head_dim: cfg.v_head_dim,
3046 };
3047
3048 Ok(Box::new(cfg))
3049 }
3050}
3051
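/// [`NormalModelLoader`] for Qwen 3 dense models, backed by `crate::models::qwen3::Config`.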
3052pub struct Qwen3Loader;
3056
3057impl NormalModelLoader for Qwen3Loader {
3058 fn load(
3059 &self,
3060 config: &str,
3061 vb: ShardedVarBuilder,
3062 normal_loading_metadata: NormalLoadingMetadata,
3063 attention_mechanism: AttentionImplementation,
3064 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3065 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3066
3067 Ok(Box::new(models::qwen3::Model::new(
3068 &cfg,
3069 vb,
3070 self.is_gptx(config)?,
3071 normal_loading_metadata,
3072 attention_mechanism,
3073 )?))
3074 }
3075 fn load_xlora(
3076 &self,
3077 _config: &str,
3078 _vb: ShardedVarBuilder,
3079 _lora_config: &[((String, String), LoraConfig)],
3080 _xlora_config: Option<XLoraConfig>,
3081 _xlora_ordering: Ordering,
3082 _normal_loading_metadata: NormalLoadingMetadata,
3083 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3084 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3085 todo!()
3086 }
3087 fn is_gptx(&self, _: &str) -> Result<bool> {
3088 Ok(true)
3089 }
3090 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3091 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3092
3093 Ok(Box::new(cfg))
3094 }
3095}
3096
3097impl IsqModelLoader for Qwen3Loader {
3098 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3099 Ok(vec![
3100 Regex::new(r"lm_head\.(weight|bias)$")?,
3101 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3103 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3104 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3105 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3106 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3108 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3109 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3110 ])
3111 }
3112 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3113 self.isq_layer_regexes(config)
3114 }
3115}
3116
3117impl DeviceMappedModelLoader for Qwen3Loader {
3118 fn mapped_max_act_size_elems(
3119 &self,
3120 config: &str,
3121 params: &AutoDeviceMapParams,
3122 ) -> Result<usize> {
3123 let AutoDeviceMapParams::Text {
3124 max_seq_len,
3125 max_batch_size,
3126 } = params
3127 else {
3128 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3129 };
3130
3131 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3132
3133 Ok(
3134 max_batch_size
3135 * cfg.num_attention_heads
3136 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3137 )
3138 }
3139 fn non_mapped_max_act_size_elems(
3140 &self,
3141 _config: &str,
3142 _params: &AutoDeviceMapParams,
3143 ) -> Result<usize> {
3144 Ok(0)
3145 }
3146
3147 fn non_mapped_size_in_bytes(
3148 &self,
3149 config: &str,
3150 dtype: DType,
3151 weight_pack_factor: usize,
3152 _matformer_config: Option<&MatformerSliceConfig>,
3153 ) -> Result<usize> {
3154 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3155 let elems = {
3156 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3157 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3159 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3160 } else {
3161 0
3162 };
3163 let norm = cfg.hidden_size;
3164 embed_tokens + lm_head + norm
3165 };
3166 Ok(elems * dtype.size_in_bytes())
3167 }
3168
3169 fn layer_sizes_in_bytes(
3170 &self,
3171 config: &str,
3172 dtype: DType,
3173 weight_pack_factor: usize,
3174 _matformer_config: Option<&MatformerSliceConfig>,
3175 ) -> Result<Vec<usize>> {
3176 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3177 let per_layer_elems = {
3178 let input_layernorm = cfg.hidden_size;
3179 let post_attention_layernorm = cfg.hidden_size;
3180
3181 let size_in = cfg.hidden_size;
3182 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3183 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3184 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3185 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3186 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3187 let o_proj = size_q * size_in / weight_pack_factor;
3188
3189 let h_size = cfg.hidden_size;
3190 let i_size = cfg.intermediate_size;
3191 let gate_proj = h_size * i_size / weight_pack_factor;
3192 let up_proj = h_size * i_size / weight_pack_factor;
3193 let down_proj = i_size * h_size / weight_pack_factor;
3194
3195 let q_norm = cfg.head_dim();
3196 let k_norm = cfg.head_dim();
3197
3198 input_layernorm
3199 + post_attention_layernorm
3200 + q_proj
3201 + k_proj
3202 + v_proj
3203 + o_proj
3204 + gate_proj
3205 + up_proj
3206 + down_proj
3207 + q_norm
3208 + k_norm
3209 };
3210 Ok(vec![
3211 per_layer_elems * dtype.size_in_bytes();
3212 cfg.num_hidden_layers
3213 ])
3214 }
3215
3216 fn num_layers(&self, config: &str) -> Result<usize> {
3217 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3218 Ok(cfg.num_hidden_layers)
3219 }
3220
3221 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3222 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3223
3224 let cfg = ModelConfigMetadata {
3225 max_seq_len: cfg.max_position_embeddings,
3226 num_layers: cfg.num_hidden_layers,
3227 hidden_size: cfg.hidden_size,
3228 num_kv_heads: cfg.num_key_value_heads,
3229 num_attn_heads: cfg.num_attention_heads,
3230 sliding_window: cfg.sliding_window,
3231 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3232 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3233 };
3234
3235 Ok(Box::new(cfg))
3236 }
3237}
3238
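/// [`NormalModelLoader`] for GLM-4 models, backed by `crate::models::glm4::Config`.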
3239pub struct GLM4Loader;
3243
3244impl NormalModelLoader for GLM4Loader {
3245 fn load(
3246 &self,
3247 config: &str,
3248 vb: ShardedVarBuilder,
3249 normal_loading_metadata: NormalLoadingMetadata,
3250 attention_mechanism: AttentionImplementation,
3251 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3252 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3253
3254 Ok(Box::new(models::glm4::Model::new(
3255 &cfg,
3256 vb,
3257 self.is_gptx(config)?,
3258 normal_loading_metadata,
3259 attention_mechanism,
3260 )?))
3261 }
3262 fn load_xlora(
3263 &self,
3264 _config: &str,
3265 _vb: ShardedVarBuilder,
3266 _lora_config: &[((String, String), LoraConfig)],
3267 _xlora_config: Option<XLoraConfig>,
3268 _xlora_ordering: Ordering,
3269 _normal_loading_metadata: NormalLoadingMetadata,
3270 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3271 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3272 todo!()
3273 }
3274 fn is_gptx(&self, _: &str) -> Result<bool> {
3275 Ok(true)
3276 }
3277 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3278 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3279
3280 Ok(Box::new(cfg))
3281 }
3282}
3283
3284impl IsqModelLoader for GLM4Loader {
3285 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3286 Ok(vec![
3287 Regex::new(r"lm_head\.(weight|bias)$")?,
3288 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3290 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3291 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3292 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3293 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3295 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3296 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3297 ])
3298 }
3299 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3300 self.isq_layer_regexes(config)
3301 }
3302}
3303
3304impl DeviceMappedModelLoader for GLM4Loader {
3305 fn mapped_max_act_size_elems(
3306 &self,
3307 config: &str,
3308 params: &AutoDeviceMapParams,
3309 ) -> Result<usize> {
3310 let AutoDeviceMapParams::Text {
3311 max_seq_len,
3312 max_batch_size,
3313 } = params
3314 else {
3315 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3316 };
3317
3318 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3319
3320 Ok(
3321 max_batch_size
3322 * cfg.num_attention_heads
3323 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3324 )
3325 }
3326 fn non_mapped_max_act_size_elems(
3327 &self,
3328 _config: &str,
3329 _params: &AutoDeviceMapParams,
3330 ) -> Result<usize> {
3331 Ok(0)
3332 }
3333
3334 fn non_mapped_size_in_bytes(
3335 &self,
3336 config: &str,
3337 dtype: DType,
3338 weight_pack_factor: usize,
3339 _matformer_config: Option<&MatformerSliceConfig>,
3340 ) -> Result<usize> {
3341 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3342 let elems = {
3343 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3344 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3346 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3347 } else {
3348 0
3349 };
3350 let norm = cfg.hidden_size;
3351 embed_tokens + lm_head + norm
3352 };
3353 Ok(elems * dtype.size_in_bytes())
3354 }
3355
3356 fn layer_sizes_in_bytes(
3357 &self,
3358 config: &str,
3359 dtype: DType,
3360 weight_pack_factor: usize,
3361 _matformer_config: Option<&MatformerSliceConfig>,
3362 ) -> Result<Vec<usize>> {
3363 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3364 let per_layer_elems = {
3365 let input_layernorm = cfg.hidden_size;
3366 let post_attention_layernorm = cfg.hidden_size * 3; // post_attention + post_self_attn + post_mlp layernorms
3367
3368 let size_in = cfg.hidden_size;
3369 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3370 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3371 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3372 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3373 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3374 let o_proj = size_q * size_in / weight_pack_factor;
3375
3376 let h_size = cfg.hidden_size;
3377 let i_size = cfg.intermediate_size;
3378 let gate_proj = h_size * i_size / weight_pack_factor;
3379 let up_proj = h_size * i_size / weight_pack_factor;
3380 let down_proj = i_size * h_size / weight_pack_factor;
3381
3382 input_layernorm
3383 + post_attention_layernorm
3384 + q_proj
3385 + k_proj
3386 + v_proj
3387 + o_proj
3388 + gate_proj
3389 + up_proj
3390 + down_proj
3391 };
3392 Ok(vec![
3393 per_layer_elems * dtype.size_in_bytes();
3394 cfg.num_hidden_layers
3395 ])
3396 }
3397
3398 fn num_layers(&self, config: &str) -> Result<usize> {
3399 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3400 Ok(cfg.num_hidden_layers)
3401 }
3402
3403 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3404 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3405
3406 let cfg = ModelConfigMetadata {
3407 max_seq_len: cfg.max_position_embeddings,
3408 num_layers: cfg.num_hidden_layers,
3409 hidden_size: cfg.hidden_size,
3410 num_kv_heads: cfg.num_key_value_heads,
3411 num_attn_heads: cfg.num_attention_heads,
3412 sliding_window: cfg.sliding_window,
3413 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3414 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3415 };
3416
3417 Ok(Box::new(cfg))
3418 }
3419}
3420
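/// [`NormalModelLoader`] for Qwen 3 MoE models, backed by `crate::models::qwen3_moe::Config`.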
3421pub struct Qwen3MoELoader;
3425
3426impl NormalModelLoader for Qwen3MoELoader {
3427 fn load(
3428 &self,
3429 config: &str,
3430 vb: ShardedVarBuilder,
3431 normal_loading_metadata: NormalLoadingMetadata,
3432 attention_mechanism: AttentionImplementation,
3433 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3434 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3435
3436 Ok(Box::new(models::qwen3_moe::Model::new(
3437 &cfg,
3438 vb,
3439 self.is_gptx(config)?,
3440 normal_loading_metadata,
3441 attention_mechanism,
3442 )?))
3443 }
3444 fn load_xlora(
3445 &self,
3446 _config: &str,
3447 _vb: ShardedVarBuilder,
3448 _lora_config: &[((String, String), LoraConfig)],
3449 _xlora_config: Option<XLoraConfig>,
3450 _xlora_ordering: Ordering,
3451 _normal_loading_metadata: NormalLoadingMetadata,
3452 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3453 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3454 todo!()
3455 }
3456 fn is_gptx(&self, _: &str) -> Result<bool> {
3457 Ok(true)
3458 }
3459 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3460 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3461
3462 Ok(Box::new(cfg))
3463 }
3464}
3465
3466impl IsqModelLoader for Qwen3MoELoader {
3467 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3468 Ok(vec![
3469 Regex::new(r"lm_head\.(weight|bias)$")?,
3470 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3472 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3473 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3474 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3475 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3477 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3478 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3479 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3481 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3482 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3483 ])
3484 }
3485 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3486 self.isq_layer_regexes(config)
3487 }
3488 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3489 self.isq_layer_regexes_moqe(config)
3490 }
3491}
3492
3493impl DeviceMappedModelLoader for Qwen3MoELoader {
3494 fn mapped_max_act_size_elems(
3495 &self,
3496 config: &str,
3497 params: &AutoDeviceMapParams,
3498 ) -> Result<usize> {
3499 let AutoDeviceMapParams::Text {
3500 max_seq_len,
3501 max_batch_size,
3502 } = params
3503 else {
3504 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3505 };
3506
3507 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3508
3509 Ok(
3510 max_batch_size
3511 * cfg.num_attention_heads
3512 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3513 )
3514 }
3515 fn non_mapped_max_act_size_elems(
3516 &self,
3517 _config: &str,
3518 _params: &AutoDeviceMapParams,
3519 ) -> Result<usize> {
3520 Ok(0)
3521 }
3522
3523 fn non_mapped_size_in_bytes(
3524 &self,
3525 config: &str,
3526 dtype: DType,
3527 weight_pack_factor: usize,
3528 _matformer_config: Option<&MatformerSliceConfig>,
3529 ) -> Result<usize> {
3530 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3531 let elems = {
3532 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3533 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3535 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3536 } else {
3537 0
3538 };
3539 let norm = cfg.hidden_size;
3540 embed_tokens + lm_head + norm
3541 };
3542 Ok(elems * dtype.size_in_bytes())
3543 }
3544
3545 fn layer_sizes_in_bytes(
3546 &self,
3547 config: &str,
3548 dtype: DType,
3549 weight_pack_factor: usize,
3550 _matformer_config: Option<&MatformerSliceConfig>,
3551 ) -> Result<Vec<usize>> {
3552 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3553
3554 let mut layer_sizes_in_bytes = Vec::new();
3555 for layer_idx in 0..cfg.num_hidden_layers {
3556 let input_layernorm = cfg.hidden_size;
3557 let post_attention_layernorm = cfg.hidden_size;
3558
3559 let size_in = cfg.hidden_size;
3560 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3561 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3562 let q_proj = size_in * size_q / weight_pack_factor;
3563 let k_proj = size_in * size_kv / weight_pack_factor;
3564 let v_proj = size_in * size_kv / weight_pack_factor;
3565 let o_proj = size_q * size_in / weight_pack_factor;
3566
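// A layer gets the sparse MoE MLP when it is not in `mlp_only_layers` and falls on a
// `decoder_sparse_step` boundary; otherwise it uses the dense MLP sizes.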
3567 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3568 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3569 {
3570 let gate_size = cfg.hidden_size * cfg.num_experts;
3571 let expert_size = {
3572 let h_size = cfg.hidden_size;
3573 let i_size = cfg.moe_intermediate_size;
3574 let gate_proj = h_size * i_size / weight_pack_factor;
3575 let up_proj = h_size * i_size / weight_pack_factor;
3576 let down_proj = i_size * h_size / weight_pack_factor;
3577 gate_proj + up_proj + down_proj
3578 };
3579 expert_size * cfg.num_experts + gate_size
3580 } else {
3581 let h_size = cfg.hidden_size;
3582 let i_size = cfg.intermediate_size;
3583 let gate_proj = h_size * i_size / weight_pack_factor;
3584 let up_proj = h_size * i_size / weight_pack_factor;
3585 let down_proj = i_size * h_size / weight_pack_factor;
3586 gate_proj + up_proj + down_proj
3587 };
3588
3589 let q_norm = cfg.head_dim();
3590 let k_norm = cfg.head_dim();
3591
3592 let size_elems = input_layernorm
3593 + post_attention_layernorm
3594 + q_proj
3595 + k_proj
3596 + v_proj
3597 + o_proj
3598 + mlp_size
3599 + q_norm
3600 + k_norm;
3601
3602 let size_in_bytes = size_elems * dtype.size_in_bytes();
3603 layer_sizes_in_bytes.push(size_in_bytes);
3604 }
3605
3606 Ok(layer_sizes_in_bytes)
3607 }
3608
3609 fn num_layers(&self, config: &str) -> Result<usize> {
3610 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3611 Ok(cfg.num_hidden_layers)
3612 }
3613
3614 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3615 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3616
3617 let cfg = ModelConfigMetadata {
3618 max_seq_len: cfg.max_position_embeddings,
3619 num_layers: cfg.num_hidden_layers,
3620 hidden_size: cfg.hidden_size,
3621 num_kv_heads: cfg.num_key_value_heads,
3622 num_attn_heads: cfg.num_attention_heads,
3623 sliding_window: cfg.sliding_window,
3624 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3625 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3626 };
3627
3628 Ok(Box::new(cfg))
3629 }
3630}
3631
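/// [`NormalModelLoader`] for SmolLM-3 models, backed by `crate::models::smollm3::Config`.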
3632pub struct SmolLm3Loader;
3638
3639impl NormalModelLoader for SmolLm3Loader {
3640 fn load(
3641 &self,
3642 config: &str,
3643 vb: ShardedVarBuilder,
3644 normal_loading_metadata: NormalLoadingMetadata,
3645 attention_mechanism: AttentionImplementation,
3646 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3647 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3648
3649 Ok(Box::new(models::smollm3::SmolLm3::new(
3650 &cfg,
3651 vb,
3652 self.is_gptx(config)?,
3653 normal_loading_metadata,
3654 attention_mechanism,
3655 )?))
3656 }
3657 fn load_xlora(
3658 &self,
3659 _config: &str,
3660 _vb: ShardedVarBuilder,
3661 _lora_config: &[((String, String), LoraConfig)],
3662 _xlora_config: Option<XLoraConfig>,
3663 _xlora_ordering: Ordering,
3664 _normal_loading_metadata: NormalLoadingMetadata,
3665 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3666 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3667 todo!()
3668 }
3669 fn is_gptx(&self, _: &str) -> Result<bool> {
3670 Ok(true)
3671 }
3672 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3673 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
3674 Ok(Box::new(cfg))
3675 }
3676}
3677
3678impl IsqModelLoader for SmolLm3Loader {
3679 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3680 Ok(vec![
3681 Regex::new(r"lm_head\.(weight|bias)$")?,
3682 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3684 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3685 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3686 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3687 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3689 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3690 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3691 ])
3692 }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

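// Size estimates used by automatic device mapping: per-layer and non-mapped weight
// footprints in bytes, plus worst-case activation element counts.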
impl DeviceMappedModelLoader for SmolLm3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

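        // Worst-case mapped activation: one attention score matrix of shape
        // [batch, heads, seq, seq], with seq capped at ATTENTION_CHUNK_SIZE.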
        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

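        // Non-mapped weights: token embeddings, the lm_head (counted unless it is tied
        // to the embeddings and left unpacked), and the final norm.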
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

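        // Per-layer element count: the two layer norms plus the attention and MLP
        // projection matrices; packed (e.g. ISQ-quantized) weights are divided by
        // weight_pack_factor.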
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

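/// Loader for GraniteMoeHybrid models (`crate::models::granite`).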
pub struct GraniteMoeHybridLoader;

impl NormalModelLoader for GraniteMoeHybridLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::granite::GraniteMoeHybrid::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for GraniteMoeHybridLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // Shared MLP
            Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let shared_i_size = cfg.shared_intermediate_size();
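            // The shared MLP's input_linear covers both the gate and up projections,
            // hence the factor of 2 below.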
            let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
            let output_linear = shared_i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + input_linear
                + output_linear
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}