1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 str::FromStr,
5 sync::Arc,
6};
7
8use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};
9
10use crate::{
11 amoe::AnyMoeBaseModelMixin,
12 device_map::DeviceMapper,
13 lora::{LoraConfig, Ordering},
14 paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
15 pipeline::{
16 isq::IsqModelLoader,
17 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
18 EitherCache, IsqModel,
19 },
20 utils::varbuilder_utils::DeviceForLoadTensor,
21 xlora_models::NonGranularState,
22};
23use anyhow::Result;
24use candle_core::{DType, Device, Tensor};
25use mistralrs_quant::log::once_log_info;
26
27use indicatif::MultiProgress;
28use mistralrs_quant::ShardedVarBuilder;
29#[cfg(feature = "pyo3_macros")]
30use pyo3::pyclass;
31
32use regex::Regex;
33use serde::Deserialize;
34
35use crate::{
36 models,
37 xlora_models::{self, XLoraConfig},
38};
39
40use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
41
42pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
43 #[allow(clippy::too_many_arguments)]
44 fn forward(
45 &self,
46 input_ids: &Tensor,
47 seqlen_offsets: &[usize],
48 context_lens: Vec<(usize, usize)>,
49 position_ids: Vec<usize>,
50 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
51 flash_params: &FlashParams,
52 ) -> candle_core::Result<Tensor>;
53 #[allow(clippy::too_many_arguments)]
54 fn xlora_forward(
55 &self,
56 input_ids: &Tensor,
57 input_ids_full: &Tensor,
58 seqlen_offsets: &[usize],
59 seqlen_offsets_full: &[usize],
60 no_kv_cache: bool,
61 non_granular_state: &Option<NonGranularState>,
62 context_lens: Vec<(usize, usize)>,
63 position_ids: Vec<usize>,
64 flash_params: &FlashParams,
65 flash_params_full: &FlashParams,
66 ) -> candle_core::Result<Tensor>;
67 fn is_xlora(&self) -> bool;
68 fn device(&self) -> &Device;
69 fn cache(&self) -> &EitherCache;
70 fn cache_mut(&mut self) -> &mut EitherCache;
71 fn max_seq_len(&self) -> usize;
72 fn config(&self) -> &ModelConfigMetadata;
73}
74
75pub struct NormalLoadingMetadata {
77 pub mapper: Box<dyn DeviceMapper + Send + Sync>,
79 pub loading_isq: bool,
81 pub real_device: Device,
83 pub multi_progress: Arc<MultiProgress>,
85 pub matformer_slicing_config: Option<MatformerSliceConfig>,
87}
88
89pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
90 fn load(
91 &self,
92 config: &str,
93 vb: ShardedVarBuilder,
94 normal_loading_metadata: NormalLoadingMetadata,
95 attention_mechanism: AttentionImplementation,
96 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
97 #[allow(clippy::too_many_arguments)]
98 fn load_xlora(
99 &self,
100 config: &str,
101 vb: ShardedVarBuilder,
102 lora_config: &[((String, String), LoraConfig)],
103 xlora_config: Option<XLoraConfig>,
104 xlora_ordering: Ordering,
105 normal_loading_metadata: NormalLoadingMetadata,
106 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
107 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
108 fn is_gptx(&self, config: &str) -> Result<bool>;
109 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
110 Ok(true)
111 }
112 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
113 fn get_device_for_tensor(
114 &self,
115 config: &str,
116 _mapper: &dyn DeviceMapper,
117 loading_isq: bool,
118 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
119 if loading_isq {
120 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
121 } else {
122 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
123 let num_layers = self.model_config(config)?.num_layers();
124 let closure = move |name: String| {
125 if let Some(captures) = re.captures(&name) {
126 captures
127 .get(1)
128 .and_then(|m| m.as_str().parse::<usize>().ok())
129 .map(|l| l.min(num_layers))
130 .map(DeviceForLoadTensor::Idx)
131 .unwrap_or(DeviceForLoadTensor::Base)
132 } else {
133 DeviceForLoadTensor::Base
134 }
135 };
136
137 Ok(Arc::new(closure))
138 }
139 }
140}
141
142#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
143#[derive(Clone, Debug, Deserialize, PartialEq)]
144pub enum NormalLoaderType {
146 #[serde(rename = "mistral")]
147 Mistral,
148 #[serde(rename = "gemma")]
149 Gemma,
150 #[serde(rename = "mixtral")]
151 Mixtral,
152 #[serde(rename = "llama")]
153 Llama,
154 #[serde(rename = "phi2")]
155 Phi2,
156 #[serde(rename = "phi3")]
157 Phi3,
158 #[serde(rename = "qwen2")]
159 Qwen2,
160 #[serde(rename = "gemma2")]
161 Gemma2,
162 #[serde(rename = "starcoder2")]
163 Starcoder2,
164 #[serde(rename = "phi3.5moe")]
165 Phi3_5MoE,
166 #[serde(rename = "deepseekv2")]
167 DeepSeekV2,
168 #[serde(rename = "deepseekv3")]
169 DeepSeekV3,
170 #[serde(rename = "qwen3")]
171 Qwen3,
172 #[serde(rename = "glm4")]
173 GLM4,
174 #[serde(rename = "qwen3moe")]
175 Qwen3Moe,
176 #[serde(rename = "smollm3")]
177 SmolLm3,
178 #[serde(rename = "granitemoehybrid")]
179 GraniteMoeHybrid,
180 #[serde(rename = "gpt_oss")]
181 GptOss,
182}
183
184impl NormalLoaderType {
186 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
187 match name {
188 "MistralForCausalLM" => Ok(Self::Mistral),
189 "MixtralForCausalLM" => Ok(Self::Mixtral),
190 "GemmaForCausalLM" => Ok(Self::Gemma),
191 "Gemma2ForCausalLM" => Ok(Self::Gemma2),
192 "PhiForCausalLM" => Ok(Self::Phi2),
193 "Phi3ForCausalLM" => Ok(Self::Phi3),
194 "LlamaForCausalLM" => Ok(Self::Llama),
195 "Qwen2ForCausalLM" => Ok(Self::Qwen2),
196 "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
197 "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
198 "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
199 "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
200 "Qwen3ForCausalLM" => Ok(Self::Qwen3),
201 "Glm4ForCausalLM" => Ok(Self::GLM4),
202 "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
203 "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
204 "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
205 "GptOssForCausalLM" => Ok(Self::GptOss),
206 other => anyhow::bail!(
207 "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
208 ),
209 }
210 }
211}
212
213impl FromStr for NormalLoaderType {
214 type Err = String;
215 fn from_str(s: &str) -> Result<Self, Self::Err> {
216 match s {
217 "mistral" => Ok(Self::Mistral),
218 "gemma" => Ok(Self::Gemma),
219 "mixtral" => Ok(Self::Mixtral),
220 "llama" => Ok(Self::Llama),
221 "phi2" => Ok(Self::Phi2),
222 "phi3" => Ok(Self::Phi3),
223 "qwen2" => Ok(Self::Qwen2),
224 "gemma2" => Ok(Self::Gemma2),
225 "starcoder2" => Ok(Self::Starcoder2),
226 "phi3.5moe" => Ok(Self::Phi3_5MoE),
227 "deepseekv2" => Ok(Self::DeepSeekV2),
228 "deepseekv3" => Ok(Self::DeepSeekV3),
229 "qwen3" => Ok(Self::Qwen3),
230 "glm4" => Ok(Self::GLM4),
231 "qwen3moe" => Ok(Self::Qwen3Moe),
232 "smollm3" => Ok(Self::SmolLm3),
233 "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
234 "gpt_oss" => Ok(Self::GptOss),
235 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`, `granitemoehybrid`, `gpt_oss`.")),
236 }
237 }
238}
239
240impl Display for NormalLoaderType {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 match self {
243 Self::Gemma => write!(f, "gemma"),
244 Self::Gemma2 => write!(f, "gemma2"),
245 Self::Llama => write!(f, "llama"),
246 Self::Mistral => write!(f, "mistral"),
247 Self::Mixtral => write!(f, "mixtral"),
248 Self::Phi2 => write!(f, "phi2"),
249 Self::Phi3 => write!(f, "phi3"),
250 Self::Phi3_5MoE => write!(f, "phi3.5moe"),
251 Self::Qwen2 => write!(f, "qwen2"),
252 Self::Starcoder2 => write!(f, "starcoder2"),
253 Self::DeepSeekV2 => write!(f, "deepseekv2"),
254 Self::DeepSeekV3 => write!(f, "deepseekv3"),
255 Self::Qwen3 => write!(f, "qwen3"),
256 Self::GLM4 => write!(f, "glm4"),
257 Self::Qwen3Moe => write!(f, "qwen3moe"),
258 Self::SmolLm3 => write!(f, "smollm3"),
259 Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
260 Self::GptOss => write!(f, "gpt_oss"),
261 }
262 }
263}
264
265macro_rules! bias_if {
266 ($cond:expr, $size:expr) => {
267 if $cond {
268 $size
269 } else {
270 0
271 }
272 };
273}
274
275pub struct AutoNormalLoader;
277
278#[derive(Deserialize)]
279struct AutoNormalLoaderConfig {
280 architectures: Vec<String>,
281}
282
283impl AutoNormalLoader {
284 fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
285 let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
286 if auto_cfg.architectures.len() != 1 {
287 anyhow::bail!("Expected to have one name for `architectures` config field.")
288 }
289
290 let name = &auto_cfg.architectures[0];
291
292 let tp = NormalLoaderType::from_causal_lm_name(name)?;
293
294 once_log_info(format!("Automatic loader type determined to be `{tp}`"));
295
296 match tp {
297 NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
298 NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
299 NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
300 NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
301 NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
302 NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
303 NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
304 NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
305 NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
306 NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
307 NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
308 NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
309 NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
310 NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
311 NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
312 NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
313 NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
314 NormalLoaderType::GptOss => Ok(Box::new(GptOssLoader)),
315 }
316 }
317}
318
319impl NormalModelLoader for AutoNormalLoader {
320 fn load(
321 &self,
322 config: &str,
323 vb: ShardedVarBuilder,
324 normal_loading_metadata: NormalLoadingMetadata,
325 attention_mechanism: AttentionImplementation,
326 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
327 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
328 }
329 fn load_xlora(
330 &self,
331 config: &str,
332 vb: ShardedVarBuilder,
333 lora_config: &[((String, String), LoraConfig)],
334 xlora_config: Option<XLoraConfig>,
335 xlora_ordering: Ordering,
336 normal_loading_metadata: NormalLoadingMetadata,
337 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
338 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
339 Self::get_loader(config)?.load_xlora(
340 config,
341 vb,
342 lora_config,
343 xlora_config,
344 xlora_ordering,
345 normal_loading_metadata,
346 preload_adapters,
347 )
348 }
349 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
350 Self::get_loader(config)?.get_config_repr(config)
351 }
352 fn supports_paged_attention(&self, config: &str) -> Result<bool> {
353 Self::get_loader(config)?.supports_paged_attention(config)
354 }
355 fn is_gptx(&self, config: &str) -> Result<bool> {
356 Self::get_loader(config)?.is_gptx(config)
357 }
358}
359
360impl IsqModelLoader for AutoNormalLoader {
361 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
362 Self::get_loader(config)?.immediate_isq_predicates(config)
363 }
364 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
365 Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
366 }
367 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
368 Self::get_loader(config)?.isq_layer_regexes(config)
369 }
370 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
371 Self::get_loader(config)?.isq_layer_regexes_moqe(config)
372 }
373}
374
375impl DeviceMappedModelLoader for AutoNormalLoader {
376 fn non_mapped_size_in_bytes(
377 &self,
378 config: &str,
379 dtype: DType,
380 weight_pack_factor: usize,
381 _matformer_config: Option<&MatformerSliceConfig>,
382 ) -> Result<usize> {
383 Self::get_loader(config)?.non_mapped_size_in_bytes(
384 config,
385 dtype,
386 weight_pack_factor,
387 _matformer_config,
388 )
389 }
390 fn num_layers(&self, config: &str) -> Result<usize> {
391 Self::get_loader(config)?.num_layers(config)
392 }
393 fn layer_sizes_in_bytes(
394 &self,
395 config: &str,
396 dtype: DType,
397 weight_pack_factor: usize,
398 _matformer_config: Option<&MatformerSliceConfig>,
399 ) -> Result<Vec<usize>> {
400 Self::get_loader(config)?.layer_sizes_in_bytes(
401 config,
402 dtype,
403 weight_pack_factor,
404 _matformer_config,
405 )
406 }
407 fn mapped_max_act_size_elems(
408 &self,
409 config: &str,
410 params: &super::AutoDeviceMapParams,
411 ) -> Result<usize> {
412 Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
413 }
414 fn non_mapped_max_act_size_elems(
415 &self,
416 _config: &str,
417 _params: &AutoDeviceMapParams,
418 ) -> Result<usize> {
419 Ok(0)
420 }
421 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
422 Self::get_loader(config)?.model_config(config)
423 }
424}
425
426pub struct MistralLoader;
429
430impl NormalModelLoader for MistralLoader {
431 fn load(
432 &self,
433 config: &str,
434 vb: ShardedVarBuilder,
435 normal_loading_metadata: NormalLoadingMetadata,
436 attention_mechanism: AttentionImplementation,
437 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
438 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
439 Ok(Box::new(models::mistral::Model::new(
440 &cfg,
441 vb,
442 self.is_gptx(config)?,
443 normal_loading_metadata,
444 attention_mechanism,
445 )?))
446 }
447 fn load_xlora(
448 &self,
449 config: &str,
450 vb: ShardedVarBuilder,
451 lora_config: &[((String, String), LoraConfig)],
452 xlora_config: Option<XLoraConfig>,
453 xlora_ordering: Ordering,
454 normal_loading_metadata: NormalLoadingMetadata,
455 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
456 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
457 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
458 Ok(Box::new(xlora_models::XLoraMistral::new(
459 &cfg,
460 vb,
461 lora_config,
462 xlora_config,
463 xlora_ordering,
464 self.is_gptx(config)?,
465 normal_loading_metadata,
466 preload_adapters,
467 )?))
468 }
469 fn is_gptx(&self, _: &str) -> Result<bool> {
470 Ok(true)
471 }
472 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
473 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
474 Ok(Box::new(cfg))
475 }
476}
477
478impl IsqModelLoader for MistralLoader {
479 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
480 Ok(vec![
481 Regex::new(r"lm_head\.(weight|bias)$")?,
482 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
484 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
485 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
486 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
487 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
489 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
490 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
491 ])
492 }
493 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
494 self.isq_layer_regexes(config)
495 }
496}
497
498impl DeviceMappedModelLoader for MistralLoader {
499 fn mapped_max_act_size_elems(
500 &self,
501 config: &str,
502 params: &AutoDeviceMapParams,
503 ) -> Result<usize> {
504 let AutoDeviceMapParams::Text {
505 max_seq_len,
506 max_batch_size,
507 } = params
508 else {
509 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
510 };
511
512 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
513
514 Ok(
515 max_batch_size
516 * cfg.num_attention_heads
517 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
518 )
519 }
520 fn non_mapped_max_act_size_elems(
521 &self,
522 _config: &str,
523 _params: &AutoDeviceMapParams,
524 ) -> Result<usize> {
525 Ok(0)
526 }
527
528 fn non_mapped_size_in_bytes(
529 &self,
530 config: &str,
531 dtype: DType,
532 weight_pack_factor: usize,
533 _matformer_config: Option<&MatformerSliceConfig>,
534 ) -> Result<usize> {
535 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
536
537 let elems = {
538 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
539 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
541 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
542 } else {
543 0
544 };
545 let norm = cfg.hidden_size;
546 embed_tokens + lm_head + norm
547 };
548 Ok(elems * dtype.size_in_bytes())
549 }
550
551 fn layer_sizes_in_bytes(
552 &self,
553 config: &str,
554 dtype: DType,
555 weight_pack_factor: usize,
556 _matformer_config: Option<&MatformerSliceConfig>,
557 ) -> Result<Vec<usize>> {
558 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
559
560 let per_layer_elems = {
561 let input_layernorm = cfg.hidden_size;
562 let post_attention_layernorm = cfg.hidden_size;
563
564 let size_in = cfg.hidden_size;
565 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
566 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
567 let q_proj = size_in * size_q / weight_pack_factor;
568 let k_proj = size_in * size_kv / weight_pack_factor;
569 let v_proj = size_in * size_kv / weight_pack_factor;
570 let o_proj = size_q * size_in / weight_pack_factor;
571
572 let h_size = cfg.hidden_size;
573 let i_size = cfg.intermediate_size;
574 let gate_proj = h_size * i_size / weight_pack_factor;
575 let up_proj = h_size * i_size / weight_pack_factor;
576 let down_proj = i_size * h_size / weight_pack_factor;
577
578 input_layernorm
579 + post_attention_layernorm
580 + q_proj
581 + k_proj
582 + v_proj
583 + o_proj
584 + gate_proj
585 + up_proj
586 + down_proj
587 };
588 Ok(vec![
589 per_layer_elems * dtype.size_in_bytes();
590 cfg.num_hidden_layers
591 ])
592 }
593
594 fn num_layers(&self, config: &str) -> Result<usize> {
595 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
596 Ok(cfg.num_hidden_layers)
597 }
598
599 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
600 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
601
602 let cfg = ModelConfigMetadata {
603 max_seq_len: cfg.max_position_embeddings,
604 num_layers: cfg.num_hidden_layers,
605 hidden_size: cfg.hidden_size,
606 num_kv_heads: cfg.num_key_value_heads,
607 num_attn_heads: cfg.num_attention_heads,
608 sliding_window: cfg.sliding_window,
609 k_head_dim: cfg.head_dim(),
610 v_head_dim: cfg.head_dim(),
611 };
612
613 Ok(Box::new(cfg))
614 }
615}
616
617pub struct GemmaLoader;
623
624impl NormalModelLoader for GemmaLoader {
625 fn load(
626 &self,
627 config: &str,
628 vb: ShardedVarBuilder,
629 normal_loading_metadata: NormalLoadingMetadata,
630 attention_mechanism: AttentionImplementation,
631 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
632 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
633
634 Ok(Box::new(models::gemma::Model::new(
635 &cfg,
636 vb,
637 self.is_gptx(config)?,
638 normal_loading_metadata,
639 attention_mechanism,
640 )?))
641 }
642 fn load_xlora(
643 &self,
644 config: &str,
645 vb: ShardedVarBuilder,
646 lora_config: &[((String, String), LoraConfig)],
647 xlora_config: Option<XLoraConfig>,
648 xlora_ordering: Ordering,
649 normal_loading_metadata: NormalLoadingMetadata,
650 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
651 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
652 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
653
654 Ok(Box::new(xlora_models::XLoraGemma::new(
655 &cfg,
656 vb,
657 lora_config,
658 xlora_config,
659 xlora_ordering,
660 self.is_gptx(config)?,
661 normal_loading_metadata,
662 preload_adapters,
663 )?))
664 }
665 fn is_gptx(&self, _: &str) -> Result<bool> {
666 Ok(true)
667 }
668 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
669 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
670 Ok(Box::new(cfg))
671 }
672}
673
674impl IsqModelLoader for GemmaLoader {
675 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
676 Ok(vec![
677 Regex::new(r"lm_head\.(weight|bias)$")?,
678 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
680 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
681 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
682 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
683 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
685 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
686 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
687 ])
688 }
689 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
690 self.isq_layer_regexes(config)
691 }
692}
693
694impl DeviceMappedModelLoader for GemmaLoader {
695 fn mapped_max_act_size_elems(
696 &self,
697 config: &str,
698 params: &AutoDeviceMapParams,
699 ) -> Result<usize> {
700 let AutoDeviceMapParams::Text {
701 max_seq_len,
702 max_batch_size,
703 } = params
704 else {
705 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
706 };
707
708 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
709
710 Ok(
711 max_batch_size
712 * cfg.num_attention_heads
713 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
714 )
715 }
716 fn non_mapped_max_act_size_elems(
717 &self,
718 _config: &str,
719 _params: &AutoDeviceMapParams,
720 ) -> Result<usize> {
721 Ok(0)
722 }
723
724 fn non_mapped_size_in_bytes(
725 &self,
726 config: &str,
727 dtype: DType,
728 weight_pack_factor: usize,
729 _matformer_config: Option<&MatformerSliceConfig>,
730 ) -> Result<usize> {
731 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
732
733 let elems = {
734 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
735 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
737 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
738 } else {
739 0
740 };
741 let norm = cfg.hidden_size;
742 embed_tokens + lm_head + norm
743 };
744 Ok(elems * dtype.size_in_bytes())
745 }
746
747 fn layer_sizes_in_bytes(
748 &self,
749 config: &str,
750 dtype: DType,
751 weight_pack_factor: usize,
752 _matformer_config: Option<&MatformerSliceConfig>,
753 ) -> Result<Vec<usize>> {
754 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
755
756 let per_layer_elems = {
757 let input_layernorm = cfg.hidden_size;
758 let post_attention_layernorm = cfg.hidden_size;
759
760 let size_in = cfg.hidden_size;
761 let size_q = cfg.head_dim * cfg.num_attention_heads;
762 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
763 let q_proj =
764 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
765 let k_proj =
766 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
767 let v_proj =
768 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
769 let o_proj =
770 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
771
772 let h_size = cfg.hidden_size;
773 let i_size = cfg.intermediate_size;
774 let gate_proj = h_size * i_size / weight_pack_factor;
775 let up_proj = h_size * i_size / weight_pack_factor;
776 let down_proj = i_size * h_size / weight_pack_factor;
777
778 input_layernorm
779 + post_attention_layernorm
780 + q_proj
781 + k_proj
782 + v_proj
783 + o_proj
784 + gate_proj
785 + up_proj
786 + down_proj
787 };
788 Ok(vec![
789 per_layer_elems * dtype.size_in_bytes();
790 cfg.num_hidden_layers
791 ])
792 }
793
794 fn num_layers(&self, config: &str) -> Result<usize> {
795 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
796 Ok(cfg.num_hidden_layers)
797 }
798
799 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
800 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
801
802 let cfg = ModelConfigMetadata {
803 max_seq_len: cfg.max_position_embeddings,
804 num_layers: cfg.num_hidden_layers,
805 hidden_size: cfg.hidden_size,
806 num_kv_heads: cfg.num_key_value_heads,
807 num_attn_heads: cfg.num_attention_heads,
808 sliding_window: None,
809 k_head_dim: cfg.head_dim,
810 v_head_dim: cfg.head_dim,
811 };
812
813 Ok(Box::new(cfg))
814 }
815}
816
817pub struct LlamaLoader;
823
824impl NormalModelLoader for LlamaLoader {
825 fn load(
826 &self,
827 config: &str,
828 vb: ShardedVarBuilder,
829 normal_loading_metadata: NormalLoadingMetadata,
830 attention_mechanism: AttentionImplementation,
831 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
832 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
833
834 Ok(Box::new(models::llama::Llama::new(
835 &cfg,
836 vb,
837 self.is_gptx(config)?,
838 normal_loading_metadata,
839 attention_mechanism,
840 )?))
841 }
842 fn load_xlora(
843 &self,
844 config: &str,
845 vb: ShardedVarBuilder,
846 lora_config: &[((String, String), LoraConfig)],
847 xlora_config: Option<XLoraConfig>,
848 xlora_ordering: Ordering,
849 normal_loading_metadata: NormalLoadingMetadata,
850 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
851 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
852 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
853
854 Ok(Box::new(xlora_models::XLoraLlama::new(
855 &cfg,
856 vb,
857 lora_config,
858 xlora_config,
859 xlora_ordering,
860 self.is_gptx(config)?,
861 normal_loading_metadata,
862 preload_adapters,
863 )?))
864 }
865 fn is_gptx(&self, _: &str) -> Result<bool> {
866 Ok(true)
867 }
868 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
869 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
870 Ok(Box::new(cfg))
871 }
872}
873
874impl IsqModelLoader for LlamaLoader {
875 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
876 Ok(vec![
877 Regex::new(r"lm_head\.(weight|bias)$")?,
878 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
880 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
881 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
882 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
883 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
885 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
886 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
887 ])
888 }
889 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
890 self.isq_layer_regexes(config)
891 }
892}
893
894impl DeviceMappedModelLoader for LlamaLoader {
895 fn mapped_max_act_size_elems(
896 &self,
897 config: &str,
898 params: &AutoDeviceMapParams,
899 ) -> Result<usize> {
900 let AutoDeviceMapParams::Text {
901 max_seq_len,
902 max_batch_size,
903 } = params
904 else {
905 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
906 };
907
908 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
909
910 Ok(
911 max_batch_size
912 * cfg.num_attention_heads
913 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
914 )
915 }
916 fn non_mapped_max_act_size_elems(
917 &self,
918 _config: &str,
919 _params: &AutoDeviceMapParams,
920 ) -> Result<usize> {
921 Ok(0)
922 }
923
924 fn non_mapped_size_in_bytes(
925 &self,
926 config: &str,
927 dtype: DType,
928 weight_pack_factor: usize,
929 _matformer_config: Option<&MatformerSliceConfig>,
930 ) -> Result<usize> {
931 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
932
933 let elems = {
934 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
935 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
937 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
938 } else {
939 0
940 };
941 let norm = cfg.hidden_size;
942 embed_tokens + lm_head + norm
943 };
944 Ok(elems * dtype.size_in_bytes())
945 }
946
947 fn layer_sizes_in_bytes(
948 &self,
949 config: &str,
950 dtype: DType,
951 weight_pack_factor: usize,
952 _matformer_config: Option<&MatformerSliceConfig>,
953 ) -> Result<Vec<usize>> {
954 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
955
956 let per_layer_elems = {
957 let input_layernorm = cfg.hidden_size;
958 let post_attention_layernorm = cfg.hidden_size;
959
960 let size_in = cfg.hidden_size;
961 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
962 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
963 let q_proj = size_in * size_q / weight_pack_factor;
964 let k_proj = size_in * size_kv / weight_pack_factor;
965 let v_proj = size_in * size_kv / weight_pack_factor;
966 let o_proj = size_q * size_in / weight_pack_factor;
967
968 let h_size = cfg.hidden_size;
969 let i_size = cfg.intermediate_size;
970 let gate_proj = h_size * i_size / weight_pack_factor;
971 let up_proj = h_size * i_size / weight_pack_factor;
972 let down_proj = i_size * h_size / weight_pack_factor;
973
974 input_layernorm
975 + post_attention_layernorm
976 + q_proj
977 + k_proj
978 + v_proj
979 + o_proj
980 + gate_proj
981 + up_proj
982 + down_proj
983 };
984 Ok(vec![
985 per_layer_elems * dtype.size_in_bytes();
986 cfg.num_hidden_layers
987 ])
988 }
989
990 fn num_layers(&self, config: &str) -> Result<usize> {
991 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
992
993 Ok(cfg.num_hidden_layers)
994 }
995 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
996 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
997
998 let cfg = ModelConfigMetadata {
999 max_seq_len: cfg.max_position_embeddings,
1000 num_layers: cfg.num_hidden_layers,
1001 hidden_size: cfg.hidden_size,
1002 num_kv_heads: cfg.num_key_value_heads,
1003 num_attn_heads: cfg.num_attention_heads,
1004 sliding_window: None,
1005 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1006 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1007 };
1008
1009 Ok(Box::new(cfg))
1010 }
1011}
1012
1013pub struct MixtralLoader;
1016
1017impl NormalModelLoader for MixtralLoader {
1018 fn load(
1019 &self,
1020 config: &str,
1021 vb: ShardedVarBuilder,
1022 normal_loading_metadata: NormalLoadingMetadata,
1023 attention_mechanism: AttentionImplementation,
1024 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1025 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1026
1027 Ok(Box::new(models::mixtral::Model::new(
1028 &cfg,
1029 vb,
1030 self.is_gptx(config)?,
1031 normal_loading_metadata,
1032 attention_mechanism,
1033 )?))
1034 }
1035 fn load_xlora(
1036 &self,
1037 config: &str,
1038 vb: ShardedVarBuilder,
1039 lora_config: &[((String, String), LoraConfig)],
1040 xlora_config: Option<XLoraConfig>,
1041 xlora_ordering: Ordering,
1042 normal_loading_metadata: NormalLoadingMetadata,
1043 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1044 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1045 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1046
1047 Ok(Box::new(xlora_models::XLoraMixtral::new(
1048 &cfg,
1049 vb,
1050 lora_config,
1051 xlora_config,
1052 xlora_ordering,
1053 self.is_gptx(config)?,
1054 normal_loading_metadata,
1055 preload_adapters,
1056 )?))
1057 }
1058 fn is_gptx(&self, _: &str) -> Result<bool> {
1059 Ok(true)
1060 }
1061 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1062 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1063
1064 Ok(Box::new(cfg))
1065 }
1066}
1067
1068impl IsqModelLoader for MixtralLoader {
1069 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1070 Ok(vec![
1071 Regex::new(r"lm_head\.(weight|bias)$")?,
1072 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1074 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1075 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1076 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1077 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1079 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1080 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1081 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1082 ])
1083 }
1084 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1085 self.isq_layer_regexes(config)
1086 }
1087}
1088
1089impl DeviceMappedModelLoader for MixtralLoader {
1090 fn mapped_max_act_size_elems(
1091 &self,
1092 config: &str,
1093 params: &AutoDeviceMapParams,
1094 ) -> Result<usize> {
1095 let AutoDeviceMapParams::Text {
1096 max_seq_len,
1097 max_batch_size,
1098 } = params
1099 else {
1100 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1101 };
1102
1103 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1104
1105 Ok(
1106 max_batch_size
1107 * cfg.num_attention_heads
1108 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1109 )
1110 }
1111 fn non_mapped_max_act_size_elems(
1112 &self,
1113 _config: &str,
1114 _params: &AutoDeviceMapParams,
1115 ) -> Result<usize> {
1116 Ok(0)
1117 }
1118
1119 fn non_mapped_size_in_bytes(
1120 &self,
1121 config: &str,
1122 dtype: DType,
1123 weight_pack_factor: usize,
1124 _matformer_config: Option<&MatformerSliceConfig>,
1125 ) -> Result<usize> {
1126 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1127
1128 let elems = {
1129 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1130 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1132 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1133 } else {
1134 0
1135 };
1136 let norm = cfg.hidden_size;
1137 embed_tokens + lm_head + norm
1138 };
1139 Ok(elems * dtype.size_in_bytes())
1140 }
1141
1142 fn layer_sizes_in_bytes(
1143 &self,
1144 config: &str,
1145 dtype: DType,
1146 weight_pack_factor: usize,
1147 _matformer_config: Option<&MatformerSliceConfig>,
1148 ) -> Result<Vec<usize>> {
1149 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1150
1151 let per_layer_elems = {
1152 let input_layernorm = cfg.hidden_size;
1153 let post_attention_layernorm = cfg.hidden_size;
1154
1155 let size_in = cfg.hidden_size;
1156 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1157 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1158 let q_proj = size_in * size_q / weight_pack_factor;
1159 let k_proj = size_in * size_kv / weight_pack_factor;
1160 let v_proj = size_in * size_kv / weight_pack_factor;
1161 let o_proj = size_q * size_in / weight_pack_factor;
1162
1163 let moe_block = {
1164 let gate = cfg.hidden_size * cfg.num_local_experts;
1165 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1167 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1168 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1169 gate + cfg.num_local_experts * w1
1170 + cfg.num_local_experts * w2
1171 + cfg.num_local_experts * w3
1172 };
1173
1174 input_layernorm
1175 + post_attention_layernorm
1176 + q_proj
1177 + k_proj
1178 + v_proj
1179 + o_proj
1180 + moe_block
1181 };
1182 Ok(vec![
1183 per_layer_elems * dtype.size_in_bytes();
1184 cfg.num_hidden_layers
1185 ])
1186 }
1187
1188 fn num_layers(&self, config: &str) -> Result<usize> {
1189 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1190
1191 Ok(cfg.num_hidden_layers)
1192 }
1193
1194 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1195 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1196
1197 let cfg = ModelConfigMetadata {
1198 max_seq_len: cfg.max_position_embeddings,
1199 num_layers: cfg.num_hidden_layers,
1200 hidden_size: cfg.hidden_size,
1201 num_kv_heads: cfg.num_key_value_heads,
1202 num_attn_heads: cfg.num_attention_heads,
1203 sliding_window: cfg.sliding_window,
1204 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1205 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1206 };
1207
1208 Ok(Box::new(cfg))
1209 }
1210}
1211
1212pub struct Phi2Loader;
1218
1219impl NormalModelLoader for Phi2Loader {
1220 fn load(
1221 &self,
1222 config: &str,
1223 vb: ShardedVarBuilder,
1224 normal_loading_metadata: NormalLoadingMetadata,
1225 attention_mechanism: AttentionImplementation,
1226 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1227 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1228
1229 Ok(Box::new(models::phi2::Model::new(
1230 &cfg,
1231 vb,
1232 self.is_gptx(config)?,
1233 normal_loading_metadata,
1234 attention_mechanism,
1235 )?))
1236 }
1237 fn load_xlora(
1238 &self,
1239 config: &str,
1240 vb: ShardedVarBuilder,
1241 lora_config: &[((String, String), LoraConfig)],
1242 xlora_config: Option<XLoraConfig>,
1243 xlora_ordering: Ordering,
1244 normal_loading_metadata: NormalLoadingMetadata,
1245 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1246 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1247 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1248
1249 Ok(Box::new(xlora_models::XLoraPhi2::new(
1250 &cfg,
1251 vb,
1252 lora_config,
1253 xlora_config,
1254 xlora_ordering,
1255 self.is_gptx(config)?,
1256 normal_loading_metadata,
1257 preload_adapters,
1258 )?))
1259 }
1260 fn is_gptx(&self, _: &str) -> Result<bool> {
1261 Ok(true)
1262 }
1263 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1264 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1265
1266 Ok(Box::new(cfg))
1267 }
1268}
1269
1270impl IsqModelLoader for Phi2Loader {
1271 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1272 Ok(vec![
1273 Regex::new(r"lm_head\.(weight|bias)$")?,
1274 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1276 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1277 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1278 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1279 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1281 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1282 ])
1283 }
1284 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1285 self.isq_layer_regexes(config)
1286 }
1287}
1288
1289impl DeviceMappedModelLoader for Phi2Loader {
1290 fn mapped_max_act_size_elems(
1291 &self,
1292 config: &str,
1293 params: &AutoDeviceMapParams,
1294 ) -> Result<usize> {
1295 let AutoDeviceMapParams::Text {
1296 max_seq_len,
1297 max_batch_size,
1298 } = params
1299 else {
1300 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1301 };
1302
1303 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1304
1305 Ok(
1306 max_batch_size
1307 * cfg.num_attention_heads
1308 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1309 )
1310 }
1311 fn non_mapped_max_act_size_elems(
1312 &self,
1313 _config: &str,
1314 _params: &AutoDeviceMapParams,
1315 ) -> Result<usize> {
1316 Ok(0)
1317 }
1318
1319 fn non_mapped_size_in_bytes(
1320 &self,
1321 config: &str,
1322 dtype: DType,
1323 weight_pack_factor: usize,
1324 _matformer_config: Option<&MatformerSliceConfig>,
1325 ) -> Result<usize> {
1326 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1327
1328 let elems = {
1329 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1330 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1332 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1333 } else {
1334 0
1335 };
1336 let norm = cfg.hidden_size;
1337 embed_tokens + lm_head + norm
1338 };
1339 Ok(elems * dtype.size_in_bytes())
1340 }
1341
1342 fn layer_sizes_in_bytes(
1343 &self,
1344 config: &str,
1345 dtype: DType,
1346 weight_pack_factor: usize,
1347 _matformer_config: Option<&MatformerSliceConfig>,
1348 ) -> Result<Vec<usize>> {
1349 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1350
1351 let per_layer_elems = {
1352 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1353
1354 let size_in = cfg.hidden_size;
1355 let size_q = cfg.head_dim() * cfg.num_attention_heads;
1356 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1357 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1358 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1359 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1360 let o_proj = size_q * size_in / weight_pack_factor + size_in;
1361 let (q_norm, k_norm) = if cfg.qk_layernorm {
1362 (cfg.head_dim(), cfg.head_dim())
1363 } else {
1364 (0, 0)
1365 };
1366
1367 let h_size = cfg.hidden_size;
1368 let i_size = cfg.intermediate_size;
1369 let fc1 = h_size * i_size / weight_pack_factor;
1370 let fc2 = h_size * i_size / weight_pack_factor;
1371
1372 input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1373 };
1374 Ok(vec![
1375 per_layer_elems * dtype.size_in_bytes();
1376 cfg.num_hidden_layers
1377 ])
1378 }
1379
1380 fn num_layers(&self, config: &str) -> Result<usize> {
1381 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1382
1383 Ok(cfg.num_hidden_layers)
1384 }
1385
1386 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1387 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1388
1389 let cfg = ModelConfigMetadata {
1390 max_seq_len: cfg.max_position_embeddings,
1391 num_layers: cfg.num_hidden_layers,
1392 hidden_size: cfg.hidden_size,
1393 num_kv_heads: cfg.num_key_value_heads(),
1394 num_attn_heads: cfg.num_attention_heads,
1395 sliding_window: None,
1396 k_head_dim: cfg.head_dim(),
1397 v_head_dim: cfg.head_dim(),
1398 };
1399
1400 Ok(Box::new(cfg))
1401 }
1402}
1403
1404pub struct Phi3Loader;
1410
1411impl NormalModelLoader for Phi3Loader {
1412 fn load(
1413 &self,
1414 config: &str,
1415 vb: ShardedVarBuilder,
1416 normal_loading_metadata: NormalLoadingMetadata,
1417 attention_mechanism: AttentionImplementation,
1418 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1419 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1420
1421 Ok(Box::new(models::phi3::Model::new(
1422 &cfg,
1423 vb,
1424 self.is_gptx(config)?,
1425 normal_loading_metadata,
1426 attention_mechanism,
1427 )?))
1428 }
1429 fn load_xlora(
1430 &self,
1431 config: &str,
1432 vb: ShardedVarBuilder,
1433 lora_config: &[((String, String), LoraConfig)],
1434 xlora_config: Option<XLoraConfig>,
1435 xlora_ordering: Ordering,
1436 normal_loading_metadata: NormalLoadingMetadata,
1437 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1438 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1439 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1440
1441 Ok(Box::new(xlora_models::XLoraPhi3::new(
1442 &cfg,
1443 vb,
1444 lora_config,
1445 xlora_config,
1446 xlora_ordering,
1447 self.is_gptx(config)?,
1448 normal_loading_metadata,
1449 preload_adapters,
1450 )?))
1451 }
1452 fn is_gptx(&self, _: &str) -> Result<bool> {
1453 Ok(true)
1454 }
1455 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1456 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1457
1458 Ok(Box::new(cfg))
1459 }
1460}
1461
1462impl IsqModelLoader for Phi3Loader {
1463 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1464 Ok(vec![
1465 Regex::new(r"lm_head\.(weight|bias)$")?,
1466 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1468 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1469 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1471 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1472 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1473 ])
1474 }
1475 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1476 self.isq_layer_regexes(config)
1477 }
1478}
1479
1480impl DeviceMappedModelLoader for Phi3Loader {
1481 fn mapped_max_act_size_elems(
1482 &self,
1483 config: &str,
1484 params: &AutoDeviceMapParams,
1485 ) -> Result<usize> {
1486 let AutoDeviceMapParams::Text {
1487 max_seq_len,
1488 max_batch_size,
1489 } = params
1490 else {
1491 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1492 };
1493
1494 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1495
1496 Ok(
1497 max_batch_size
1498 * cfg.num_attention_heads
1499 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1500 )
1501 }
1502 fn non_mapped_max_act_size_elems(
1503 &self,
1504 _config: &str,
1505 _params: &AutoDeviceMapParams,
1506 ) -> Result<usize> {
1507 Ok(0)
1508 }
1509
1510 fn non_mapped_size_in_bytes(
1511 &self,
1512 config: &str,
1513 dtype: DType,
1514 weight_pack_factor: usize,
1515 _matformer_config: Option<&MatformerSliceConfig>,
1516 ) -> Result<usize> {
1517 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1518
1519 let elems = {
1520 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1521 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1523 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1524 } else {
1525 0
1526 };
1527 let norm = cfg.hidden_size;
1528 embed_tokens + lm_head + norm
1529 };
1530 Ok(elems * dtype.size_in_bytes())
1531 }
1532
1533 fn layer_sizes_in_bytes(
1534 &self,
1535 config: &str,
1536 dtype: DType,
1537 weight_pack_factor: usize,
1538 _matformer_config: Option<&MatformerSliceConfig>,
1539 ) -> Result<Vec<usize>> {
1540 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1541
1542 let per_layer_elems = {
1543 let input_layernorm = cfg.hidden_size;
1544 let post_attention_layernorm = cfg.hidden_size;
1545
1546 let size_in = cfg.hidden_size;
1547 let head_dim = cfg.head_dim();
1548 let op_size =
1549 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1550 let qkv_proj = size_in * op_size / weight_pack_factor;
1551 let o_proj =
1552 (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1553
1554 let h_size = cfg.hidden_size;
1555 let i_size = cfg.intermediate_size;
1556 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1557 let down_proj = h_size * i_size / weight_pack_factor;
1558
1559 input_layernorm
1560 + post_attention_layernorm
1561 + qkv_proj
1562 + o_proj
1563 + gate_up_proj
1564 + down_proj
1565 };
1566 Ok(vec![
1567 per_layer_elems * dtype.size_in_bytes();
1568 cfg.num_hidden_layers
1569 ])
1570 }
1571
1572 fn num_layers(&self, config: &str) -> Result<usize> {
1573 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1574
1575 Ok(cfg.num_hidden_layers)
1576 }
1577
1578 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1579 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1580
1581 let cfg = ModelConfigMetadata {
1582 max_seq_len: cfg.max_position_embeddings,
1583 num_layers: cfg.num_hidden_layers,
1584 hidden_size: cfg.hidden_size,
1585 num_kv_heads: cfg.num_key_value_heads,
1586 num_attn_heads: cfg.num_attention_heads,
1587 sliding_window: cfg.sliding_window,
1588 k_head_dim: cfg.head_dim(),
1589 v_head_dim: cfg.head_dim(),
1590 };
1591
1592 Ok(Box::new(cfg))
1593 }
1594}
1595
1596pub struct Qwen2Loader;
1602
1603impl NormalModelLoader for Qwen2Loader {
1604 fn load(
1605 &self,
1606 config: &str,
1607 vb: ShardedVarBuilder,
1608 normal_loading_metadata: NormalLoadingMetadata,
1609 attention_mechanism: AttentionImplementation,
1610 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1611 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1612
1613 Ok(Box::new(models::qwen2::Model::new(
1614 &cfg,
1615 vb,
1616 self.is_gptx(config)?,
1617 normal_loading_metadata,
1618 attention_mechanism,
1619 )?))
1620 }
1621 fn load_xlora(
1622 &self,
1623 _config: &str,
1624 _vb: ShardedVarBuilder,
1625 _lora_config: &[((String, String), LoraConfig)],
1626 _xlora_config: Option<XLoraConfig>,
1627 _xlora_ordering: Ordering,
1628 _normal_loading_metadata: NormalLoadingMetadata,
1629 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1630 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1631 todo!()
1632 }
1633 fn is_gptx(&self, _: &str) -> Result<bool> {
1634 Ok(true)
1635 }
1636 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1637 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1638
1639 Ok(Box::new(cfg))
1640 }
1641}
1642
1643impl IsqModelLoader for Qwen2Loader {
1644 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1645 Ok(vec![
1646 Regex::new(r"lm_head\.(weight|bias)$")?,
1647 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1649 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1650 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1651 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1652 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1654 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1655 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1656 ])
1657 }
1658 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1659 self.isq_layer_regexes(config)
1660 }
1661}
1662
1663impl DeviceMappedModelLoader for Qwen2Loader {
1664 fn mapped_max_act_size_elems(
1665 &self,
1666 config: &str,
1667 params: &AutoDeviceMapParams,
1668 ) -> Result<usize> {
1669 let AutoDeviceMapParams::Text {
1670 max_seq_len,
1671 max_batch_size,
1672 } = params
1673 else {
1674 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1675 };
1676
1677 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1678
1679 Ok(
1680 max_batch_size
1681 * cfg.num_attention_heads
1682 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1683 )
1684 }
1685 fn non_mapped_max_act_size_elems(
1686 &self,
1687 _config: &str,
1688 _params: &AutoDeviceMapParams,
1689 ) -> Result<usize> {
1690 Ok(0)
1691 }
1692
1693 fn non_mapped_size_in_bytes(
1694 &self,
1695 config: &str,
1696 dtype: DType,
1697 weight_pack_factor: usize,
1698 _matformer_config: Option<&MatformerSliceConfig>,
1699 ) -> Result<usize> {
1700 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1701
1702 let elems = {
1703 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1704 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1706 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1707 } else {
1708 0
1709 };
1710 let norm = cfg.hidden_size;
1711 embed_tokens + lm_head + norm
1712 };
1713 Ok(elems * dtype.size_in_bytes())
1714 }
1715
1716 fn layer_sizes_in_bytes(
1717 &self,
1718 config: &str,
1719 dtype: DType,
1720 weight_pack_factor: usize,
1721 _matformer_config: Option<&MatformerSliceConfig>,
1722 ) -> Result<Vec<usize>> {
1723 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1724
1725 let per_layer_elems = {
1726 let input_layernorm = cfg.hidden_size;
1727 let post_attention_layernorm = cfg.hidden_size;
1728
1729 let size_in = cfg.hidden_size;
1730 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1731 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1732 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1733 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1734 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1735 let o_proj = size_q * size_in / weight_pack_factor;
1736
1737 let h_size = cfg.hidden_size;
1738 let i_size = cfg.intermediate_size;
1739 let gate_proj = h_size * i_size / weight_pack_factor;
1740 let up_proj = h_size * i_size / weight_pack_factor;
1741 let down_proj = i_size * h_size / weight_pack_factor;
1742
1743 input_layernorm
1744 + post_attention_layernorm
1745 + q_proj
1746 + k_proj
1747 + v_proj
1748 + o_proj
1749 + gate_proj
1750 + up_proj
1751 + down_proj
1752 };
1753 Ok(vec![
1754 per_layer_elems * dtype.size_in_bytes();
1755 cfg.num_hidden_layers
1756 ])
1757 }
1758
1759 fn num_layers(&self, config: &str) -> Result<usize> {
1760 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1761
1762 Ok(cfg.num_hidden_layers)
1763 }
1764
1765 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1766 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1767
1768 let cfg = ModelConfigMetadata {
1769 max_seq_len: cfg.max_position_embeddings,
1770 num_layers: cfg.num_hidden_layers,
1771 hidden_size: cfg.hidden_size,
1772 num_kv_heads: cfg.num_key_value_heads,
1773 num_attn_heads: cfg.num_attention_heads,
1774 sliding_window: cfg.sliding_window,
1775 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1776 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1777 };
1778
1779 Ok(Box::new(cfg))
1780 }
1781}
1782
1783pub struct Gemma2Loader;
1789
1790impl NormalModelLoader for Gemma2Loader {
1791 fn load(
1792 &self,
1793 config: &str,
1794 vb: ShardedVarBuilder,
1795 normal_loading_metadata: NormalLoadingMetadata,
1796 attention_mechanism: AttentionImplementation,
1797 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1798 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1799
1800 Ok(Box::new(models::gemma2::Model::new(
1801 &cfg,
1802 vb,
1803 self.is_gptx(config)?,
1804 normal_loading_metadata,
1805 attention_mechanism,
1806 )?))
1807 }
1808 fn load_xlora(
1809 &self,
1810 config: &str,
1811 vb: ShardedVarBuilder,
1812 lora_config: &[((String, String), LoraConfig)],
1813 xlora_config: Option<XLoraConfig>,
1814 xlora_ordering: Ordering,
1815 normal_loading_metadata: NormalLoadingMetadata,
1816 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1817 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1818 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1819
1820 Ok(Box::new(xlora_models::XLoraGemma2::new(
1821 &cfg,
1822 vb,
1823 lora_config,
1824 xlora_config,
1825 xlora_ordering,
1826 self.is_gptx(config)?,
1827 normal_loading_metadata,
1828 preload_adapters,
1829 )?))
1830 }
1831 fn is_gptx(&self, _: &str) -> Result<bool> {
1832 Ok(true)
1833 }
1834 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1835 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1836
1837 Ok(Box::new(cfg))
1838 }
1839}
1840
1841impl IsqModelLoader for Gemma2Loader {
1842 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1843 Ok(vec![
1844 Regex::new(r"lm_head\.(weight|bias)$")?,
1845 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1847 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1848 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1849 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1850 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1852 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1853 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1854 ])
1855 }
1856 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1857 self.isq_layer_regexes(config)
1858 }
1859}
1860
1861impl DeviceMappedModelLoader for Gemma2Loader {
1862 fn mapped_max_act_size_elems(
1863 &self,
1864 config: &str,
1865 params: &AutoDeviceMapParams,
1866 ) -> Result<usize> {
1867 let AutoDeviceMapParams::Text {
1868 max_seq_len,
1869 max_batch_size,
1870 } = params
1871 else {
1872 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1873 };
1874
1875 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1876
1877 Ok(
1878 max_batch_size
1879 * cfg.num_attention_heads
1880 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1881 )
1882 }
1883 fn non_mapped_max_act_size_elems(
1884 &self,
1885 _config: &str,
1886 _params: &AutoDeviceMapParams,
1887 ) -> Result<usize> {
1888 Ok(0)
1889 }
1890
1891 fn non_mapped_size_in_bytes(
1892 &self,
1893 config: &str,
1894 dtype: DType,
1895 weight_pack_factor: usize,
1896 _matformer_config: Option<&MatformerSliceConfig>,
1897 ) -> Result<usize> {
1898 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1899
1900 let elems = {
1901 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
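// lm_head contributes nothing only when it is tied to embed_tokens and no weight
// packing is in effect; otherwise it is counted as a separate matrix.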
1902 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1904 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1905 } else {
1906 0
1907 };
1908 let norm = cfg.hidden_size;
1909 embed_tokens + lm_head + norm
1910 };
1911 Ok(elems * dtype.size_in_bytes())
1912 }
1913
1914 fn layer_sizes_in_bytes(
1915 &self,
1916 config: &str,
1917 dtype: DType,
1918 weight_pack_factor: usize,
1919 _matformer_config: Option<&MatformerSliceConfig>,
1920 ) -> Result<Vec<usize>> {
1921 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1922
1923 let per_layer_elems = {
1924 let input_layernorm = cfg.hidden_size;
1925 let post_attention_layernorm = cfg.hidden_size;
1926
1927 let size_in = cfg.hidden_size;
1928 let size_q = cfg.head_dim * cfg.num_attention_heads;
1929 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1930 let q_proj =
1931 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1932 let k_proj =
1933 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1934 let v_proj =
1935 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1936 let o_proj =
1937 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1938
1939 let h_size = cfg.hidden_size;
1940 let i_size = cfg.intermediate_size;
1941 let gate_proj = h_size * i_size / weight_pack_factor;
1942 let up_proj = h_size * i_size / weight_pack_factor;
1943 let down_proj = i_size * h_size / weight_pack_factor;
1944
1945 input_layernorm
1946 + post_attention_layernorm
1947 + q_proj
1948 + k_proj
1949 + v_proj
1950 + o_proj
1951 + gate_proj
1952 + up_proj
1953 + down_proj
1954 };
1955 Ok(vec![
1956 per_layer_elems * dtype.size_in_bytes();
1957 cfg.num_hidden_layers
1958 ])
1959 }
1960
1961 fn num_layers(&self, config: &str) -> Result<usize> {
1962 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1963
1964 Ok(cfg.num_hidden_layers)
1965 }
1966 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1967 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1968
1969 let cfg = ModelConfigMetadata {
1970 max_seq_len: cfg.max_position_embeddings,
1971 num_layers: cfg.num_hidden_layers,
1972 hidden_size: cfg.hidden_size,
1973 num_kv_heads: cfg.num_key_value_heads,
1974 num_attn_heads: cfg.num_attention_heads,
1975 sliding_window: None,
1976 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1977 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1978 };
1979
1980 Ok(Box::new(cfg))
1981 }
1982}
1983
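/// Loader for StarCoder2 models (`models::starcoder2::Model`); X-LoRA variants load via `xlora_models::XLoraStarcoder2`.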
1984pub struct Starcoder2Loader;
1990
1991impl NormalModelLoader for Starcoder2Loader {
1992 fn load(
1993 &self,
1994 config: &str,
1995 vb: ShardedVarBuilder,
1996 normal_loading_metadata: NormalLoadingMetadata,
1997 attention_mechanism: AttentionImplementation,
1998 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1999 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2000
2001 Ok(Box::new(models::starcoder2::Model::new(
2002 &cfg,
2003 vb,
2004 self.is_gptx(config)?,
2005 normal_loading_metadata,
2006 attention_mechanism,
2007 )?))
2008 }
2009 fn load_xlora(
2010 &self,
2011 config: &str,
2012 vb: ShardedVarBuilder,
2013 lora_config: &[((String, String), LoraConfig)],
2014 xlora_config: Option<XLoraConfig>,
2015 xlora_ordering: Ordering,
2016 normal_loading_metadata: NormalLoadingMetadata,
2017 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2018 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2019 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2020
2021 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2022 &cfg,
2023 vb,
2024 lora_config,
2025 xlora_config,
2026 xlora_ordering,
2027 self.is_gptx(config)?,
2028 normal_loading_metadata,
2029 preload_adapters,
2030 )?))
2031 }
2032 fn is_gptx(&self, _: &str) -> Result<bool> {
2033 Ok(true)
2034 }
2035 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2036 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2037
2038 Ok(Box::new(cfg))
2039 }
2040}
2041
2042impl IsqModelLoader for Starcoder2Loader {
2043 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2044 Ok(vec![
2045 Regex::new(r"lm_head\.(weight|bias)$")?,
2046 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2048 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2049 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2050 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2051 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2053 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2054 ])
2055 }
2056 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2057 self.isq_layer_regexes(config)
2058 }
2059}
2060
2061impl DeviceMappedModelLoader for Starcoder2Loader {
2062 fn mapped_max_act_size_elems(
2063 &self,
2064 config: &str,
2065 params: &AutoDeviceMapParams,
2066 ) -> Result<usize> {
2067 let AutoDeviceMapParams::Text {
2068 max_seq_len,
2069 max_batch_size,
2070 } = params
2071 else {
2072 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2073 };
2074
2075 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2076
2077 Ok(
2078 max_batch_size
2079 * cfg.num_attention_heads
2080 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2081 )
2082 }
2083 fn non_mapped_max_act_size_elems(
2084 &self,
2085 _config: &str,
2086 _params: &AutoDeviceMapParams,
2087 ) -> Result<usize> {
2088 Ok(0)
2089 }
2090
2091 fn non_mapped_size_in_bytes(
2092 &self,
2093 config: &str,
2094 dtype: DType,
2095 weight_pack_factor: usize,
2096 _matformer_config: Option<&MatformerSliceConfig>,
2097 ) -> Result<usize> {
2098 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2099
2100 let elems = {
2101 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2102 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2104 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2105 } else {
2106 0
2107 };
2108 let norm = cfg.hidden_size + cfg.hidden_size;
2109 embed_tokens + lm_head + norm
2110 };
2111 Ok(elems * dtype.size_in_bytes())
2112 }
2113
2114 fn layer_sizes_in_bytes(
2115 &self,
2116 config: &str,
2117 dtype: DType,
2118 weight_pack_factor: usize,
2119 _matformer_config: Option<&MatformerSliceConfig>,
2120 ) -> Result<Vec<usize>> {
2121 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2122
2123 let per_layer_elems = {
2124 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2125 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2126
2127 let size_in = cfg.hidden_size;
2128 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2129 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2130 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2131 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2132 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2133 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2134
2135 let h_size = cfg.hidden_size;
2136 let i_size = cfg.intermediate_size;
2137 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2138 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2139
2140 input_layernorm
2141 + post_attention_layernorm
2142 + q_proj
2143 + k_proj
2144 + v_proj
2145 + o_proj
2146 + fc1
2147 + fc2
2148 };
2149 Ok(vec![
2150 per_layer_elems * dtype.size_in_bytes();
2151 cfg.num_hidden_layers
2152 ])
2153 }
2154
2155 fn num_layers(&self, config: &str) -> Result<usize> {
2156 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2157
2158 Ok(cfg.num_hidden_layers)
2159 }
2160
2161 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2162 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2163
2164 let cfg = ModelConfigMetadata {
2165 max_seq_len: cfg.max_position_embeddings,
2166 num_layers: cfg.num_hidden_layers,
2167 hidden_size: cfg.hidden_size,
2168 num_kv_heads: cfg.num_key_value_heads,
2169 num_attn_heads: cfg.num_attention_heads,
2170 sliding_window: cfg.sliding_window,
2171 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2172 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2173 };
2174
2175 Ok(Box::new(cfg))
2176 }
2177}
2178
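/// Loader for Phi 3.5 MoE models (`models::phi3_5_moe::Model`).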
2179pub struct Phi3_5MoELoader;
2185
2186impl NormalModelLoader for Phi3_5MoELoader {
2187 fn load(
2188 &self,
2189 config: &str,
2190 vb: ShardedVarBuilder,
2191 normal_loading_metadata: NormalLoadingMetadata,
2192 attention_mechanism: AttentionImplementation,
2193 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2194 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2195
2196 Ok(Box::new(models::phi3_5_moe::Model::new(
2197 &cfg,
2198 vb,
2199 self.is_gptx(config)?,
2200 normal_loading_metadata,
2201 attention_mechanism,
2202 )?))
2203 }
2204 fn load_xlora(
2205 &self,
2206 config: &str,
2207 vb: ShardedVarBuilder,
2208 lora_config: &[((String, String), LoraConfig)],
2209 xlora_config: Option<XLoraConfig>,
2210 xlora_ordering: Ordering,
2211 normal_loading_metadata: NormalLoadingMetadata,
2212 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2213 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2214 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2215
2216 Ok(Box::new(xlora_models::XLoraPhi3::new(
2217 &cfg,
2218 vb,
2219 lora_config,
2220 xlora_config,
2221 xlora_ordering,
2222 self.is_gptx(config)?,
2223 normal_loading_metadata,
2224 preload_adapters,
2225 )?))
2226 }
2227 fn is_gptx(&self, _: &str) -> Result<bool> {
2228 Ok(true)
2229 }
2230 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2231 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2232
2233 Ok(Box::new(cfg))
2234 }
2235}
2236
2237impl IsqModelLoader for Phi3_5MoELoader {
2238 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2239 Ok(vec![
2240 Regex::new(r"lm_head\.(weight|bias)$")?,
2241 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2243 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2244 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2245 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2246 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2248 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2249 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2250 ])
2251 }
2252 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2253 self.isq_layer_regexes(config)
2254 }
2255
2256 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2257 Ok(vec![
2258 Regex::new(r"lm_head\.(weight|bias)$")?,
2259 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2261 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2262 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2263 ])
2264 }
2265 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2266 self.isq_layer_regexes_moqe(config)
2267 }
2268}
2269
2270impl DeviceMappedModelLoader for Phi3_5MoELoader {
2271 fn mapped_max_act_size_elems(
2272 &self,
2273 config: &str,
2274 params: &AutoDeviceMapParams,
2275 ) -> Result<usize> {
2276 let AutoDeviceMapParams::Text {
2277 max_seq_len,
2278 max_batch_size,
2279 } = params
2280 else {
2281 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2282 };
2283
2284 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2285
2286 Ok(
2287 max_batch_size
2288 * cfg.num_attention_heads
2289 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2290 )
2291 }
2292 fn non_mapped_max_act_size_elems(
2293 &self,
2294 _config: &str,
2295 _params: &AutoDeviceMapParams,
2296 ) -> Result<usize> {
2297 Ok(0)
2298 }
2299
2300 fn non_mapped_size_in_bytes(
2301 &self,
2302 config: &str,
2303 dtype: DType,
2304 weight_pack_factor: usize,
2305 _matformer_config: Option<&MatformerSliceConfig>,
2306 ) -> Result<usize> {
2307 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2308
2309 let elems = {
2310 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2311 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2313 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2314 } else {
2315 0
2316 };
2317 let norm = cfg.hidden_size;
2318 embed_tokens + lm_head + norm
2319 };
2320 Ok(elems * dtype.size_in_bytes())
2321 }
2322
2323 fn layer_sizes_in_bytes(
2324 &self,
2325 config: &str,
2326 dtype: DType,
2327 weight_pack_factor: usize,
2328 _matformer_config: Option<&MatformerSliceConfig>,
2329 ) -> Result<Vec<usize>> {
2330 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2331
2332 let per_layer_elems = {
2333 let input_layernorm = cfg.hidden_size;
2334 let post_attention_layernorm = cfg.hidden_size;
2335
2336 let size_in = cfg.hidden_size;
2337 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2338 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2339 let q_proj =
2340 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2341 let k_proj =
2342 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2343 let v_proj =
2344 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2345 let o_proj =
2346 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2347
2348 let moe_block = {
2349 let gate = cfg.hidden_size * cfg.num_local_experts;
2350 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2352 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2353 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2354 gate + cfg.num_local_experts * w1
2355 + cfg.num_local_experts * w2
2356 + cfg.num_local_experts * w3
2357 };
2358
2359 input_layernorm
2360 + post_attention_layernorm
2361 + q_proj
2362 + k_proj
2363 + v_proj
2364 + o_proj
2365 + moe_block
2366 };
2367 Ok(vec![
2368 per_layer_elems * dtype.size_in_bytes();
2369 cfg.num_hidden_layers
2370 ])
2371 }
2372
2373 fn num_layers(&self, config: &str) -> Result<usize> {
2374 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2375
2376 Ok(cfg.num_hidden_layers)
2377 }
2378
2379 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2380 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2381
2382 let cfg = ModelConfigMetadata {
2383 max_seq_len: cfg.max_position_embeddings,
2384 num_layers: cfg.num_hidden_layers,
2385 hidden_size: cfg.hidden_size,
2386 num_kv_heads: cfg.num_key_value_heads,
2387 num_attn_heads: cfg.num_attention_heads,
2388 sliding_window: cfg.sliding_window,
2389 k_head_dim: cfg.head_dim(),
2390 v_head_dim: cfg.head_dim(),
2391 };
2392
2393 Ok(Box::new(cfg))
2394 }
2395}
2396
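/// Loader for DeepSeek V2 models (`models::deepseek2::DeepSeekV2`); X-LoRA loading is not implemented.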
2397pub struct DeepSeekV2Loader;
2401
2402impl NormalModelLoader for DeepSeekV2Loader {
2403 fn load(
2404 &self,
2405 config: &str,
2406 vb: ShardedVarBuilder,
2407 normal_loading_metadata: NormalLoadingMetadata,
2408 attention_mechanism: AttentionImplementation,
2409 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2410 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2411
2412 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2413 &cfg,
2414 vb,
2415 self.is_gptx(config)?,
2416 normal_loading_metadata,
2417 attention_mechanism,
2418 )?))
2419 }
2420 fn load_xlora(
2421 &self,
2422 _config: &str,
2423 _vb: ShardedVarBuilder,
2424 _lora_config: &[((String, String), LoraConfig)],
2425 _xlora_config: Option<XLoraConfig>,
2426 _xlora_ordering: Ordering,
2427 _normal_loading_metadata: NormalLoadingMetadata,
2428 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2429 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2430 todo!()
2431 }
2432 fn is_gptx(&self, _: &str) -> Result<bool> {
2433 Ok(true)
2434 }
2435 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2436 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2437 Ok(Box::new(cfg))
2438 }
2439}
2440
2441impl IsqModelLoader for DeepSeekV2Loader {
2442 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2443 let mut data = vec![
2444 Regex::new(r"lm_head\.(weight|bias)$")?,
2445 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2447 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2448 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2449 ];
2450 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2451 if cfg.q_lora_rank.is_some() {
2452 data.extend(vec![
2453 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2454 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2455 ]);
2456 } else {
2457 data.push(Regex::new(
2458 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2459 )?);
2460 }
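// Layers at or past first_k_dense_replace that land on a moe_layer_freq boundary use
// per-expert (and optional shared-expert) projections; the rest use a dense MLP.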
2461 for layer_idx in 0..cfg.num_hidden_layers {
2462 if cfg.n_routed_experts.is_some()
2463 && layer_idx >= cfg.first_k_dense_replace
2464 && layer_idx % cfg.moe_layer_freq == 0
2465 {
2466 for i in 0..cfg.n_routed_experts.unwrap() {
2467 data.extend(vec![
2468 Regex::new(&format!(
2469 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2470 ))?,
2471 Regex::new(&format!(
2472 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2473 ))?,
2474 Regex::new(&format!(
2475 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2476 ))?,
2477 ]);
2478 }
2479 if cfg.n_shared_experts.is_some() {
2480 data.extend(vec![
2481 Regex::new(&format!(
2482 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2483 ))?,
2484 Regex::new(&format!(
2485 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2486 ))?,
2487 Regex::new(&format!(
2488 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2489 ))?,
2490 ]);
2491 }
2492 } else {
2493 data.extend(vec![
2494 Regex::new(&format!(
2495 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2496 ))?,
2497 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2498 Regex::new(&format!(
2499 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2500 ))?,
2501 ]);
2502 };
2503 }
2504 Ok(data)
2505 }
2506 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2507 self.isq_layer_regexes(config)
2508 }
2509
2510 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2511 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2512 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2513 for layer_idx in 0..cfg.num_hidden_layers {
2514 if cfg.n_routed_experts.is_some()
2515 && layer_idx >= cfg.first_k_dense_replace
2516 && layer_idx % cfg.moe_layer_freq == 0
2517 {
2518 for i in 0..cfg.n_routed_experts.unwrap() {
2519 data.extend(vec![
2520 Regex::new(&format!(
2521 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2522 ))?,
2523 Regex::new(&format!(
2524 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2525 ))?,
2526 Regex::new(&format!(
2527 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2528 ))?,
2529 ]);
2530 }
2531 if cfg.n_shared_experts.is_some() {
2532 data.extend(vec![
2533 Regex::new(&format!(
2534 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2535 ))?,
2536 Regex::new(&format!(
2537 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2538 ))?,
2539 Regex::new(&format!(
2540 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2541 ))?,
2542 ]);
2543 }
2544 } else {
2545 data.extend(vec![
2546 Regex::new(&format!(
2547 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2548 ))?,
2549 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2550 Regex::new(&format!(
2551 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2552 ))?,
2553 ]);
2554 };
2555 }
2556 Ok(data)
2557 }
2558 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2559 self.isq_layer_regexes_moqe(config)
2560 }
2561}
2562
2563impl DeviceMappedModelLoader for DeepSeekV2Loader {
2564 fn mapped_max_act_size_elems(
2565 &self,
2566 config: &str,
2567 params: &AutoDeviceMapParams,
2568 ) -> Result<usize> {
2569 let AutoDeviceMapParams::Text {
2570 max_seq_len,
2571 max_batch_size,
2572 } = params
2573 else {
2574 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2575 };
2576
2577 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2578
2579 Ok(
2580 max_batch_size
2581 * cfg.num_attention_heads
2582 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2583 )
2584 }
2585 fn non_mapped_max_act_size_elems(
2586 &self,
2587 _config: &str,
2588 _params: &AutoDeviceMapParams,
2589 ) -> Result<usize> {
2590 Ok(0)
2591 }
2592
2593 fn non_mapped_size_in_bytes(
2594 &self,
2595 config: &str,
2596 dtype: DType,
2597 weight_pack_factor: usize,
2598 _matformer_config: Option<&MatformerSliceConfig>,
2599 ) -> Result<usize> {
2600 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2601 let elems = {
2602 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2603 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2605 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2606 } else {
2607 0
2608 };
2609 let norm = cfg.hidden_size;
2610 embed_tokens + lm_head + norm
2611 };
2612 Ok(elems * dtype.size_in_bytes())
2613 }
2614
2615 fn layer_sizes_in_bytes(
2616 &self,
2617 config: &str,
2618 dtype: DType,
2619 weight_pack_factor: usize,
2620 _matformer_config: Option<&MatformerSliceConfig>,
2621 ) -> Result<Vec<usize>> {
2622 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2623 let mut per_layer_elems = Vec::new();
2624
2625 for layer_idx in 0..cfg.num_hidden_layers {
2626 let input_layernorm = cfg.hidden_size;
2627 let post_attention_layernorm = cfg.hidden_size;
2628
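// Multi-head latent attention: Q may be factored through a low-rank a/b pair, and K/V
// are compressed to kv_lora_rank (+ RoPE dims) by kv_a_proj_with_mqa before kv_b_proj
// expands them per head.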
2629 let q_proj = match cfg.q_lora_rank {
2630 Some(lora_rank) => {
2631 let a = cfg.hidden_size * lora_rank;
2632 let norm = lora_rank;
2633 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2634 a + norm + b
2635 }
2636 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2637 };
2638 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2639 / weight_pack_factor
2640 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2641 let kv_a_layernorm = cfg.kv_lora_rank;
2642 let kv_b_proj = cfg.kv_lora_rank
2643 * cfg.num_attention_heads
2644 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2645 / weight_pack_factor;
2646 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2647 / weight_pack_factor
2648 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2649
2650 let moe_block = {
2651 let mut sum = 0;
2652 if cfg.n_routed_experts.is_some()
2653 && layer_idx >= cfg.first_k_dense_replace
2654 && layer_idx % cfg.moe_layer_freq == 0
2655 {
2656 let h_size = cfg.hidden_size;
2657 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2658 * cfg.n_routed_experts.unwrap();
2659 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2660 * cfg.n_routed_experts.unwrap();
2661 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2662 * cfg.n_routed_experts.unwrap();
2663 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2664 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2665 / weight_pack_factor;
2666 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2667 / weight_pack_factor;
2668 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2669 / weight_pack_factor;
2670 gate_proj + up_proj + down_proj
2671 } else {
2672 0
2673 };
2674 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2675 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2676 } else {
2677 let h_size = cfg.hidden_size;
2678 let i_size = cfg.intermediate_size;
2679 let gate_proj = h_size * i_size / weight_pack_factor;
2680 let up_proj = h_size * i_size / weight_pack_factor;
2681 let down_proj = i_size * h_size / weight_pack_factor;
2682 sum += gate_proj + up_proj + down_proj;
2683 }
2684 sum
2685 };
2686
2687 per_layer_elems.push(
2688 input_layernorm
2689 + post_attention_layernorm
2690 + q_proj
2691 + kv_a_layernorm
2692 + kv_a_proj_with_mqa
2693 + kv_b_proj
2694 + o_proj
2695 + moe_block,
2696 );
2697 }
2698
2699 Ok(per_layer_elems
2700 .into_iter()
2701 .map(|x| x * dtype.size_in_bytes())
2702 .collect())
2703 }
2704
2705 fn num_layers(&self, config: &str) -> Result<usize> {
2706 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2707 Ok(cfg.num_hidden_layers)
2708 }
2709
2710 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2711 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2712
2713 let cfg = ModelConfigMetadata {
2714 max_seq_len: cfg.max_position_embeddings,
2715 num_layers: cfg.num_hidden_layers,
2716 hidden_size: cfg.hidden_size,
2717 num_kv_heads: cfg.num_attention_heads,
2718 num_attn_heads: cfg.num_attention_heads,
2719 sliding_window: None,
2720 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2721 v_head_dim: cfg.v_head_dim,
2722 };
2723
2724 Ok(Box::new(cfg))
2725 }
2726}
2727
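/// Loader for DeepSeek V3 models (`models::deepseek3::DeepSeekV3`); X-LoRA loading is not implemented.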
2728pub struct DeepSeekV3Loader;
2732
2733impl NormalModelLoader for DeepSeekV3Loader {
2734 fn load(
2735 &self,
2736 config: &str,
2737 vb: ShardedVarBuilder,
2738 normal_loading_metadata: NormalLoadingMetadata,
2739 attention_mechanism: AttentionImplementation,
2740 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2741 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2742 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2743 &cfg,
2744 vb,
2745 self.is_gptx(config)?,
2746 normal_loading_metadata,
2747 attention_mechanism,
2748 )?))
2749 }
2750 fn load_xlora(
2751 &self,
2752 _config: &str,
2753 _vb: ShardedVarBuilder,
2754 _lora_config: &[((String, String), LoraConfig)],
2755 _xlora_config: Option<XLoraConfig>,
2756 _xlora_ordering: Ordering,
2757 _normal_loading_metadata: NormalLoadingMetadata,
2758 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2759 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2760 todo!()
2761 }
2762 fn is_gptx(&self, _: &str) -> Result<bool> {
2763 Ok(true)
2764 }
2765 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2766 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2767 Ok(Box::new(cfg))
2768 }
2769}
2770
2771impl IsqModelLoader for DeepSeekV3Loader {
2772 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2773 let mut data = vec![
2774 Regex::new(r"lm_head\.(weight|bias)$")?,
2775 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2777 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2778 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2779 ];
2780 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2781 if cfg.q_lora_rank.is_some() {
2782 data.extend(vec![
2783 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2784 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2785 ]);
2786 } else {
2787 data.push(Regex::new(
2788 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2789 )?);
2790 }
2791 for layer_idx in 0..cfg.num_hidden_layers {
2792 if cfg.n_routed_experts.is_some()
2793 && layer_idx >= cfg.first_k_dense_replace
2794 && layer_idx % cfg.moe_layer_freq == 0
2795 {
2796 for i in 0..cfg.n_routed_experts.unwrap() {
2797 data.extend(vec![
2798 Regex::new(&format!(
2799 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2800 ))?,
2801 Regex::new(&format!(
2802 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2803 ))?,
2804 Regex::new(&format!(
2805 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2806 ))?,
2807 ]);
2808 }
2809 if cfg.n_shared_experts.is_some() {
2810 data.extend(vec![
2811 Regex::new(&format!(
2812 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2813 ))?,
2814 Regex::new(&format!(
2815 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2816 ))?,
2817 Regex::new(&format!(
2818 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2819 ))?,
2820 ]);
2821 }
2822 } else {
2823 data.extend(vec![
2824 Regex::new(&format!(
2825 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2826 ))?,
2827 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2828 Regex::new(&format!(
2829 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2830 ))?,
2831 ]);
2832 };
2833 }
2834 Ok(data)
2835 }
2836 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2837 self.isq_layer_regexes(config)
2838 }
2839
2840 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2841 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2842 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2843 for layer_idx in 0..cfg.num_hidden_layers {
2844 if cfg.n_routed_experts.is_some()
2845 && layer_idx >= cfg.first_k_dense_replace
2846 && layer_idx % cfg.moe_layer_freq == 0
2847 {
2848 for i in 0..cfg.n_routed_experts.unwrap() {
2849 data.extend(vec![
2850 Regex::new(&format!(
2851 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2852 ))?,
2853 Regex::new(&format!(
2854 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2855 ))?,
2856 Regex::new(&format!(
2857 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2858 ))?,
2859 ]);
2860 }
2861 if cfg.n_shared_experts.is_some() {
2862 data.extend(vec![
2863 Regex::new(&format!(
2864 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2865 ))?,
2866 Regex::new(&format!(
2867 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2868 ))?,
2869 Regex::new(&format!(
2870 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2871 ))?,
2872 ]);
2873 }
2874 } else {
2875 data.extend(vec![
2876 Regex::new(&format!(
2877 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2878 ))?,
2879 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2880 Regex::new(&format!(
2881 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2882 ))?,
2883 ]);
2884 };
2885 }
2886 Ok(data)
2887 }
2888 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2889 self.isq_layer_regexes_moqe(config)
2890 }
2891}
2892
2893impl DeviceMappedModelLoader for DeepSeekV3Loader {
2894 fn mapped_max_act_size_elems(
2895 &self,
2896 config: &str,
2897 params: &AutoDeviceMapParams,
2898 ) -> Result<usize> {
2899 let AutoDeviceMapParams::Text {
2900 max_seq_len,
2901 max_batch_size,
2902 } = params
2903 else {
2904 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2905 };
2906
2907 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2908
2909 Ok(
2910 max_batch_size
2911 * cfg.num_attention_heads
2912 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2913 )
2914 }
2915 fn non_mapped_max_act_size_elems(
2916 &self,
2917 _config: &str,
2918 _params: &AutoDeviceMapParams,
2919 ) -> Result<usize> {
2920 Ok(0)
2921 }
2922
2923 fn non_mapped_size_in_bytes(
2924 &self,
2925 config: &str,
2926 dtype: DType,
2927 weight_pack_factor: usize,
2928 _matformer_config: Option<&MatformerSliceConfig>,
2929 ) -> Result<usize> {
2930 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2931 let elems = {
2932 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2933 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2935 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2936 } else {
2937 0
2938 };
2939 let norm = cfg.hidden_size;
2940 embed_tokens + lm_head + norm
2941 };
2942 Ok(elems * dtype.size_in_bytes())
2943 }
2944
2945 fn layer_sizes_in_bytes(
2946 &self,
2947 config: &str,
2948 dtype: DType,
2949 weight_pack_factor: usize,
2950 _matformer_config: Option<&MatformerSliceConfig>,
2951 ) -> Result<Vec<usize>> {
2952 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2953 let mut per_layer_elems = Vec::new();
2954
2955 for layer_idx in 0..cfg.num_hidden_layers {
2956 let input_layernorm = cfg.hidden_size;
2957 let post_attention_layernorm = cfg.hidden_size;
2958
2959 let q_proj = match cfg.q_lora_rank {
2960 Some(lora_rank) => {
2961 let a = cfg.hidden_size * lora_rank;
2962 let norm = lora_rank;
2963 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2964 a + norm + b
2965 }
2966 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2967 };
2968 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2969 / weight_pack_factor
2970 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2971 let kv_a_layernorm = cfg.kv_lora_rank;
2972 let kv_b_proj = cfg.kv_lora_rank
2973 * cfg.num_attention_heads
2974 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2975 / weight_pack_factor;
2976 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2977 / weight_pack_factor
2978 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2979
2980 let moe_block = {
2981 let mut sum = 0;
2982 if cfg.n_routed_experts.is_some()
2983 && layer_idx >= cfg.first_k_dense_replace
2984 && layer_idx % cfg.moe_layer_freq == 0
2985 {
2986 let h_size = cfg.hidden_size;
2987 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2988 * cfg.n_routed_experts.unwrap();
2989 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2990 * cfg.n_routed_experts.unwrap();
2991 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2992 * cfg.n_routed_experts.unwrap();
2993 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2994 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2995 / weight_pack_factor;
2996 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2997 / weight_pack_factor;
2998 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2999 / weight_pack_factor;
3000 gate_proj + up_proj + down_proj
3001 } else {
3002 0
3003 };
3004 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
3005 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3006 } else {
3007 let h_size = cfg.hidden_size;
3008 let i_size = cfg.intermediate_size;
3009 let gate_proj = h_size * i_size / weight_pack_factor;
3010 let up_proj = h_size * i_size / weight_pack_factor;
3011 let down_proj = i_size * h_size / weight_pack_factor;
3012 sum += gate_proj + up_proj + down_proj;
3013 }
3014 sum
3015 };
3016
3017 per_layer_elems.push(
3018 input_layernorm
3019 + post_attention_layernorm
3020 + q_proj
3021 + kv_a_layernorm
3022 + kv_a_proj_with_mqa
3023 + kv_b_proj
3024 + o_proj
3025 + moe_block,
3026 );
3027 }
3028
3029 Ok(per_layer_elems
3030 .into_iter()
3031 .map(|x| x * dtype.size_in_bytes())
3032 .collect())
3033 }
3034
3035 fn num_layers(&self, config: &str) -> Result<usize> {
3036 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3037 Ok(cfg.num_hidden_layers)
3038 }
3039
3040 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3041 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3042
3043 let cfg = ModelConfigMetadata {
3044 max_seq_len: cfg.max_position_embeddings,
3045 num_layers: cfg.num_hidden_layers,
3046 hidden_size: cfg.hidden_size,
3047 num_kv_heads: cfg.num_attention_heads,
3048 num_attn_heads: cfg.num_attention_heads,
3049 sliding_window: None,
3050 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3051 v_head_dim: cfg.v_head_dim,
3052 };
3053
3054 Ok(Box::new(cfg))
3055 }
3056}
3057
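/// Loader for Qwen 3 dense models (`models::qwen3::Model`); X-LoRA loading is not implemented.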
3058pub struct Qwen3Loader;
3062
3063impl NormalModelLoader for Qwen3Loader {
3064 fn load(
3065 &self,
3066 config: &str,
3067 vb: ShardedVarBuilder,
3068 normal_loading_metadata: NormalLoadingMetadata,
3069 attention_mechanism: AttentionImplementation,
3070 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3071 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3072
3073 Ok(Box::new(models::qwen3::Model::new(
3074 &cfg,
3075 vb,
3076 self.is_gptx(config)?,
3077 normal_loading_metadata,
3078 attention_mechanism,
3079 )?))
3080 }
3081 fn load_xlora(
3082 &self,
3083 _config: &str,
3084 _vb: ShardedVarBuilder,
3085 _lora_config: &[((String, String), LoraConfig)],
3086 _xlora_config: Option<XLoraConfig>,
3087 _xlora_ordering: Ordering,
3088 _normal_loading_metadata: NormalLoadingMetadata,
3089 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3090 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3091 todo!()
3092 }
3093 fn is_gptx(&self, _: &str) -> Result<bool> {
3094 Ok(true)
3095 }
3096 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3097 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3098
3099 Ok(Box::new(cfg))
3100 }
3101}
3102
3103impl IsqModelLoader for Qwen3Loader {
3104 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3105 Ok(vec![
3106 Regex::new(r"lm_head\.(weight|bias)$")?,
3107 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3109 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3110 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3111 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3112 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3114 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3115 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3116 ])
3117 }
3118 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3119 self.isq_layer_regexes(config)
3120 }
3121}
3122
3123impl DeviceMappedModelLoader for Qwen3Loader {
3124 fn mapped_max_act_size_elems(
3125 &self,
3126 config: &str,
3127 params: &AutoDeviceMapParams,
3128 ) -> Result<usize> {
3129 let AutoDeviceMapParams::Text {
3130 max_seq_len,
3131 max_batch_size,
3132 } = params
3133 else {
3134 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3135 };
3136
3137 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3138
3139 Ok(
3140 max_batch_size
3141 * cfg.num_attention_heads
3142 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3143 )
3144 }
3145 fn non_mapped_max_act_size_elems(
3146 &self,
3147 _config: &str,
3148 _params: &AutoDeviceMapParams,
3149 ) -> Result<usize> {
3150 Ok(0)
3151 }
3152
3153 fn non_mapped_size_in_bytes(
3154 &self,
3155 config: &str,
3156 dtype: DType,
3157 weight_pack_factor: usize,
3158 _matformer_config: Option<&MatformerSliceConfig>,
3159 ) -> Result<usize> {
3160 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3161 let elems = {
3162 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3163 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3165 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3166 } else {
3167 0
3168 };
3169 let norm = cfg.hidden_size;
3170 embed_tokens + lm_head + norm
3171 };
3172 Ok(elems * dtype.size_in_bytes())
3173 }
3174
3175 fn layer_sizes_in_bytes(
3176 &self,
3177 config: &str,
3178 dtype: DType,
3179 weight_pack_factor: usize,
3180 _matformer_config: Option<&MatformerSliceConfig>,
3181 ) -> Result<Vec<usize>> {
3182 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3183 let per_layer_elems = {
3184 let input_layernorm = cfg.hidden_size;
3185 let post_attention_layernorm = cfg.hidden_size;
3186
3187 let size_in = cfg.hidden_size;
3188 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3189 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3190 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3191 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3192 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3193 let o_proj = size_q * size_in / weight_pack_factor;
3194
3195 let h_size = cfg.hidden_size;
3196 let i_size = cfg.intermediate_size;
3197 let gate_proj = h_size * i_size / weight_pack_factor;
3198 let up_proj = h_size * i_size / weight_pack_factor;
3199 let down_proj = i_size * h_size / weight_pack_factor;
3200
3201 let q_norm = cfg.head_dim();
3202 let k_norm = cfg.head_dim();
3203
3204 input_layernorm
3205 + post_attention_layernorm
3206 + q_proj
3207 + k_proj
3208 + v_proj
3209 + o_proj
3210 + gate_proj
3211 + up_proj
3212 + down_proj
3213 + q_norm
3214 + k_norm
3215 };
3216 Ok(vec![
3217 per_layer_elems * dtype.size_in_bytes();
3218 cfg.num_hidden_layers
3219 ])
3220 }
3221
3222 fn num_layers(&self, config: &str) -> Result<usize> {
3223 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3224 Ok(cfg.num_hidden_layers)
3225 }
3226
3227 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3228 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3229
3230 let cfg = ModelConfigMetadata {
3231 max_seq_len: cfg.max_position_embeddings,
3232 num_layers: cfg.num_hidden_layers,
3233 hidden_size: cfg.hidden_size,
3234 num_kv_heads: cfg.num_key_value_heads,
3235 num_attn_heads: cfg.num_attention_heads,
3236 sliding_window: cfg.sliding_window,
3237 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3238 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3239 };
3240
3241 Ok(Box::new(cfg))
3242 }
3243}
3244
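/// Loader for GLM-4 models (`models::glm4::Model`); X-LoRA loading is not implemented.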
3245pub struct GLM4Loader;
3249
3250impl NormalModelLoader for GLM4Loader {
3251 fn load(
3252 &self,
3253 config: &str,
3254 vb: ShardedVarBuilder,
3255 normal_loading_metadata: NormalLoadingMetadata,
3256 attention_mechanism: AttentionImplementation,
3257 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3258 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3259
3260 Ok(Box::new(models::glm4::Model::new(
3261 &cfg,
3262 vb,
3263 self.is_gptx(config)?,
3264 normal_loading_metadata,
3265 attention_mechanism,
3266 )?))
3267 }
3268 fn load_xlora(
3269 &self,
3270 _config: &str,
3271 _vb: ShardedVarBuilder,
3272 _lora_config: &[((String, String), LoraConfig)],
3273 _xlora_config: Option<XLoraConfig>,
3274 _xlora_ordering: Ordering,
3275 _normal_loading_metadata: NormalLoadingMetadata,
3276 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3277 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3278 todo!()
3279 }
3280 fn is_gptx(&self, _: &str) -> Result<bool> {
3281 Ok(true)
3282 }
3283 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3284 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3285
3286 Ok(Box::new(cfg))
3287 }
3288}
3289
3290impl IsqModelLoader for GLM4Loader {
3291 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3292 Ok(vec![
3293 Regex::new(r"lm_head\.(weight|bias)$")?,
3294 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3296 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3297 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3298 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3299 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3301 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3302 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3303 ])
3304 }
3305 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3306 self.isq_layer_regexes(config)
3307 }
3308}
3309
3310impl DeviceMappedModelLoader for GLM4Loader {
3311 fn mapped_max_act_size_elems(
3312 &self,
3313 config: &str,
3314 params: &AutoDeviceMapParams,
3315 ) -> Result<usize> {
3316 let AutoDeviceMapParams::Text {
3317 max_seq_len,
3318 max_batch_size,
3319 } = params
3320 else {
3321 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3322 };
3323
3324 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3325
3326 Ok(
3327 max_batch_size
3328 * cfg.num_attention_heads
3329 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3330 )
3331 }
3332 fn non_mapped_max_act_size_elems(
3333 &self,
3334 _config: &str,
3335 _params: &AutoDeviceMapParams,
3336 ) -> Result<usize> {
3337 Ok(0)
3338 }
3339
3340 fn non_mapped_size_in_bytes(
3341 &self,
3342 config: &str,
3343 dtype: DType,
3344 weight_pack_factor: usize,
3345 _matformer_config: Option<&MatformerSliceConfig>,
3346 ) -> Result<usize> {
3347 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3348 let elems = {
3349 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3350 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3352 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3353 } else {
3354 0
3355 };
3356 let norm = cfg.hidden_size;
3357 embed_tokens + lm_head + norm
3358 };
3359 Ok(elems * dtype.size_in_bytes())
3360 }
3361
3362 fn layer_sizes_in_bytes(
3363 &self,
3364 config: &str,
3365 dtype: DType,
3366 weight_pack_factor: usize,
3367 _matformer_config: Option<&MatformerSliceConfig>,
3368 ) -> Result<Vec<usize>> {
3369 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3370 let per_layer_elems = {
3371 let input_layernorm = cfg.hidden_size;
3372 let post_attention_layernorm = cfg.hidden_size * 3; // also counts post_self_attn_layernorm and post_mlp_layernorm
3373
3374 let size_in = cfg.hidden_size;
3375 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3376 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3377 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3378 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3379 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3380 let o_proj = size_q * size_in / weight_pack_factor;
3381
3382 let h_size = cfg.hidden_size;
3383 let i_size = cfg.intermediate_size;
3384 let gate_proj = h_size * i_size / weight_pack_factor;
3385 let up_proj = h_size * i_size / weight_pack_factor;
3386 let down_proj = i_size * h_size / weight_pack_factor;
3387
3388 input_layernorm
3389 + post_attention_layernorm
3390 + q_proj
3391 + k_proj
3392 + v_proj
3393 + o_proj
3394 + gate_proj
3395 + up_proj
3396 + down_proj
3397 };
3398 Ok(vec![
3399 per_layer_elems * dtype.size_in_bytes();
3400 cfg.num_hidden_layers
3401 ])
3402 }
3403
3404 fn num_layers(&self, config: &str) -> Result<usize> {
3405 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3406 Ok(cfg.num_hidden_layers)
3407 }
3408
3409 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3410 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3411
3412 let cfg = ModelConfigMetadata {
3413 max_seq_len: cfg.max_position_embeddings,
3414 num_layers: cfg.num_hidden_layers,
3415 hidden_size: cfg.hidden_size,
3416 num_kv_heads: cfg.num_key_value_heads,
3417 num_attn_heads: cfg.num_attention_heads,
3418 sliding_window: cfg.sliding_window,
3419 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3420 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3421 };
3422
3423 Ok(Box::new(cfg))
3424 }
3425}
3426
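/// Loader for Qwen 3 MoE models (`models::qwen3_moe::Model`); X-LoRA loading is not implemented.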
3427pub struct Qwen3MoELoader;
3431
3432impl NormalModelLoader for Qwen3MoELoader {
3433 fn load(
3434 &self,
3435 config: &str,
3436 vb: ShardedVarBuilder,
3437 normal_loading_metadata: NormalLoadingMetadata,
3438 attention_mechanism: AttentionImplementation,
3439 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3440 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3441
3442 Ok(Box::new(models::qwen3_moe::Model::new(
3443 &cfg,
3444 vb,
3445 self.is_gptx(config)?,
3446 normal_loading_metadata,
3447 attention_mechanism,
3448 )?))
3449 }
3450 fn load_xlora(
3451 &self,
3452 _config: &str,
3453 _vb: ShardedVarBuilder,
3454 _lora_config: &[((String, String), LoraConfig)],
3455 _xlora_config: Option<XLoraConfig>,
3456 _xlora_ordering: Ordering,
3457 _normal_loading_metadata: NormalLoadingMetadata,
3458 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3459 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3460 todo!()
3461 }
3462 fn is_gptx(&self, _: &str) -> Result<bool> {
3463 Ok(true)
3464 }
3465 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3466 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3467
3468 Ok(Box::new(cfg))
3469 }
3470}
3471
3472impl IsqModelLoader for Qwen3MoELoader {
3473 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3474 Ok(vec![
3475 Regex::new(r"lm_head\.(weight|bias)$")?,
3476 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3478 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3479 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3480 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3481 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3483 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3484 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3485 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3487 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3488 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3489 ])
3490 }
3491 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3492 self.isq_layer_regexes(config)
3493 }
3494 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3495 self.isq_layer_regexes_moqe(config)
3496 }
3497}
3498
3499impl DeviceMappedModelLoader for Qwen3MoELoader {
3500 fn mapped_max_act_size_elems(
3501 &self,
3502 config: &str,
3503 params: &AutoDeviceMapParams,
3504 ) -> Result<usize> {
3505 let AutoDeviceMapParams::Text {
3506 max_seq_len,
3507 max_batch_size,
3508 } = params
3509 else {
3510 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3511 };
3512
3513 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3514
3515 Ok(
3516 max_batch_size
3517 * cfg.num_attention_heads
3518 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3519 )
3520 }
3521 fn non_mapped_max_act_size_elems(
3522 &self,
3523 _config: &str,
3524 _params: &AutoDeviceMapParams,
3525 ) -> Result<usize> {
3526 Ok(0)
3527 }
3528
3529 fn non_mapped_size_in_bytes(
3530 &self,
3531 config: &str,
3532 dtype: DType,
3533 weight_pack_factor: usize,
3534 _matformer_config: Option<&MatformerSliceConfig>,
3535 ) -> Result<usize> {
3536 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3537 let elems = {
3538 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3539 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3541 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3542 } else {
3543 0
3544 };
3545 let norm = cfg.hidden_size;
3546 embed_tokens + lm_head + norm
3547 };
3548 Ok(elems * dtype.size_in_bytes())
3549 }
3550
3551 fn layer_sizes_in_bytes(
3552 &self,
3553 config: &str,
3554 dtype: DType,
3555 weight_pack_factor: usize,
3556 _matformer_config: Option<&MatformerSliceConfig>,
3557 ) -> Result<Vec<usize>> {
3558 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3559
3560 let mut layer_sizes_in_bytes = Vec::new();
3561 for layer_idx in 0..cfg.num_hidden_layers {
3562 let input_layernorm = cfg.hidden_size;
3563 let post_attention_layernorm = cfg.hidden_size;
3564
3565 let size_in = cfg.hidden_size;
3566 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3567 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3568 let q_proj = size_in * size_q / weight_pack_factor;
3569 let k_proj = size_in * size_kv / weight_pack_factor;
3570 let v_proj = size_in * size_kv / weight_pack_factor;
3571 let o_proj = size_q * size_in / weight_pack_factor;
3572
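// A layer uses the sparse MoE block unless it is listed in mlp_only_layers or falls
// off the decoder_sparse_step cadence, in which case it is a dense MLP.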
3573 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3574 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3575 {
3576 let gate_size = cfg.hidden_size * cfg.num_experts;
3577 let expert_size = {
3578 let h_size = cfg.hidden_size;
3579 let i_size = cfg.moe_intermediate_size;
3580 let gate_proj = h_size * i_size / weight_pack_factor;
3581 let up_proj = h_size * i_size / weight_pack_factor;
3582 let down_proj = i_size * h_size / weight_pack_factor;
3583 gate_proj + up_proj + down_proj
3584 };
3585 expert_size * cfg.num_experts + gate_size
3586 } else {
3587 let h_size = cfg.hidden_size;
3588 let i_size = cfg.intermediate_size;
3589 let gate_proj = h_size * i_size / weight_pack_factor;
3590 let up_proj = h_size * i_size / weight_pack_factor;
3591 let down_proj = i_size * h_size / weight_pack_factor;
3592 gate_proj + up_proj + down_proj
3593 };
3594
3595 let q_norm = cfg.head_dim();
3596 let k_norm = cfg.head_dim();
3597
3598 let size_elems = input_layernorm
3599 + post_attention_layernorm
3600 + q_proj
3601 + k_proj
3602 + v_proj
3603 + o_proj
3604 + mlp_size
3605 + q_norm
3606 + k_norm;
3607
3608 let size_in_bytes = size_elems * dtype.size_in_bytes();
3609 layer_sizes_in_bytes.push(size_in_bytes);
3610 }
3611
3612 Ok(layer_sizes_in_bytes)
3613 }
3614
3615 fn num_layers(&self, config: &str) -> Result<usize> {
3616 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3617 Ok(cfg.num_hidden_layers)
3618 }
3619
3620 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3621 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3622
3623 let cfg = ModelConfigMetadata {
3624 max_seq_len: cfg.max_position_embeddings,
3625 num_layers: cfg.num_hidden_layers,
3626 hidden_size: cfg.hidden_size,
3627 num_kv_heads: cfg.num_key_value_heads,
3628 num_attn_heads: cfg.num_attention_heads,
3629 sliding_window: cfg.sliding_window,
3630 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3631 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3632 };
3633
3634 Ok(Box::new(cfg))
3635 }
3636}
3637
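/// Loader for SmolLM3 models (`models::smollm3::SmolLm3`); X-LoRA loading is not implemented.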
pub struct SmolLm3Loader;

impl NormalModelLoader for SmolLm3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::smollm3::SmolLm3::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for SmolLm3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for SmolLm3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

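        // Peak mapped activation is estimated from the attention score matrix:
        // batch * n_heads * min(seq_len, ATTENTION_CHUNK_SIZE)^2 elements.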
        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
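        // Rough illustration with hypothetical dims (hidden_size = 2048, 32 heads, 8 KV heads,
        // intermediate_size = 8192, weight_pack_factor = 1):
        //   2*2048 + 2*(2048*2048) + 2*(2048*512) + 3*(2048*8192) ≈ 60.8M elements,
        //   i.e. roughly 116 MiB per layer in BF16.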
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

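/// Loader for Granite MoE hybrid models (backed by `models::granite`).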
pub struct GraniteMoeHybridLoader;

impl NormalModelLoader for GraniteMoeHybridLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::granite::GraniteMoeHybrid::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for GraniteMoeHybridLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let shared_i_size = cfg.shared_intermediate_size();
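            // `input_linear` is sized for the fused gate/up projections of the shared MLP
            // (hence the factor of 2); `output_linear` maps back to the hidden size.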
            let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
            let output_linear = shared_i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + input_linear
                + output_linear
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

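/// Loader for GPT-OSS models (backed by `models::gpt_oss`). The per-layer size estimate
/// below assumes the expert weights are MXFP4-packed.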
pub struct GptOssLoader;

impl NormalModelLoader for GptOssLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::gpt_oss::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        anyhow::bail!("GPT-OSS does not support X-LoRA")
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
        Ok(false)
    }
}

impl IsqModelLoader for GptOssLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GptOssLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;

        Ok(
            max_batch_size
                * cfg.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
        )
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let size_q = head_dim * cfg.num_attention_heads;
            let size_kv = head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let mxfp4_pack = 2;
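            // MXFP4 packs two 4-bit values per byte (the /2 below) and carries one scale
            // per 32-element block (the /32 below); biases, router, and sinks stay unpacked.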
            let gate_up_proj_size =
                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / mxfp4_pack;
            let down_proj_size =
                cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / mxfp4_pack;
            let gate_up_scales =
                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / 32;
            let down_scales = cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / 32;
            let gate_up_bias = cfg.num_local_experts * cfg.intermediate_size * 2;
            let down_bias = cfg.num_local_experts * cfg.hidden_size;
            let router = cfg.hidden_size * cfg.num_local_experts;
            let sinks = cfg.num_attention_heads;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_up_proj_size
                + down_proj_size
                + gate_up_scales
                + down_scales
                + gate_up_bias
                + down_bias
                + router
                + sinks
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;

        let head_dim = cfg.head_dim();
        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: head_dim,
            v_head_dim: head_dim,
        };

        Ok(Box::new(cfg))
    }
}