1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 str::FromStr,
5 sync::Arc,
6};
7
8use crate::{
9 amoe::AnyMoeBaseModelMixin,
10 device_map::DeviceMapper,
11 lora::{LoraConfig, Ordering},
12 paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
13 pipeline::{
14 isq::IsqModelLoader,
15 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
16 EitherCache, IsqModel,
17 },
18 utils::varbuilder_utils::DeviceForLoadTensor,
19 xlora_models::NonGranularState,
20};
21use anyhow::Result;
22use candle_core::{DType, Device, Tensor};
23use mistralrs_quant::log::once_log_info;
24
25use indicatif::MultiProgress;
26use mistralrs_quant::ShardedVarBuilder;
27#[cfg(feature = "pyo3_macros")]
28use pyo3::pyclass;
29
30use regex::Regex;
31use serde::Deserialize;
32
33use crate::{
34 models,
35 xlora_models::{self, XLoraConfig},
36};
37
38use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
39
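/// Interface implemented by every "normal" (text decoder) model in this crate. It exposes
/// the plain forward pass, the X-LoRA forward pass, and accessors for the KV cache, device,
/// maximum sequence length, and model configuration that the pipelines rely on.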
40pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
41 #[allow(clippy::too_many_arguments)]
42 fn forward(
43 &self,
44 input_ids: &Tensor,
45 seqlen_offsets: &[usize],
46 context_lens: Vec<(usize, usize)>,
47 position_ids: Vec<usize>,
48 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
49 flash_params: &FlashParams,
50 ) -> candle_core::Result<Tensor>;
51 #[allow(clippy::too_many_arguments)]
52 fn xlora_forward(
53 &self,
54 input_ids: &Tensor,
55 input_ids_full: &Tensor,
56 seqlen_offsets: &[usize],
57 seqlen_offsets_full: &[usize],
58 no_kv_cache: bool,
59 non_granular_state: &Option<NonGranularState>,
60 context_lens: Vec<(usize, usize)>,
61 position_ids: Vec<usize>,
62 flash_params: &FlashParams,
63 flash_params_full: &FlashParams,
64 ) -> candle_core::Result<Tensor>;
65 fn is_xlora(&self) -> bool;
66 fn device(&self) -> &Device;
67 fn cache(&self) -> &EitherCache;
68 fn cache_mut(&mut self) -> &mut EitherCache;
69 fn max_seq_len(&self) -> usize;
70 fn config(&self) -> &ModelConfigMetadata;
71}
72
73pub struct NormalLoadingMetadata {
74 // Device mapping metadata which can be used to construct a concrete device mapper
75 pub mapper: Box<dyn DeviceMapper + Send + Sync>,
76 // Flag to check if loading in ISQ
77 pub loading_isq: bool,
78 // Device mapping target device (the one that is not the CPU)
79 pub real_device: Device,
80 // MultiProgress support for parallelized loading
81 pub multi_progress: Arc<MultiProgress>,
83}
84
85pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
86 fn load(
87 &self,
88 config: &str,
89 vb: ShardedVarBuilder,
90 normal_loading_metadata: NormalLoadingMetadata,
91 attention_mechanism: AttentionImplementation,
92 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
93 #[allow(clippy::too_many_arguments)]
94 fn load_xlora(
95 &self,
96 config: &str,
97 vb: ShardedVarBuilder,
98 lora_config: &[((String, String), LoraConfig)],
99 xlora_config: Option<XLoraConfig>,
100 xlora_ordering: Ordering,
101 normal_loading_metadata: NormalLoadingMetadata,
102 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
103 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
104 fn is_gptx(&self, config: &str) -> Result<bool>;
105 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
106 Ok(true)
107 }
108 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
109 fn get_device_for_tensor(
110 &self,
111 config: &str,
112 _mapper: &dyn DeviceMapper,
113 loading_isq: bool,
114 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
115 if loading_isq {
116 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
117 } else {
118 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
119 let num_layers = self.model_config(config)?.num_layers();
120 let closure = move |name: String| {
121 if let Some(captures) = re.captures(&name) {
122 captures
123 .get(1)
124 .and_then(|m| m.as_str().parse::<usize>().ok())
125 .map(|l| l.min(num_layers))
126 .map(DeviceForLoadTensor::Idx)
127 .unwrap_or(DeviceForLoadTensor::Base)
128 } else {
129 DeviceForLoadTensor::Base
130 }
131 };
132
133 Ok(Arc::new(closure))
134 }
135 }
136}
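// Illustrative note (not part of the original source): the closure returned by
// `get_device_for_tensor` maps tensor names onto layer indices so the device mapper can
// place each weight. A hypothetical name such as `model.layers.10.self_attn.q_proj.weight`
// matches `\.layers\.(\d+)\.` and yields `DeviceForLoadTensor::Idx(10)`, while a name with
// no layer index (e.g. `model.embed_tokens.weight`) falls back to `DeviceForLoadTensor::Base`.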
137
138#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
139#[derive(Clone, Debug, Deserialize, PartialEq)]
140pub enum NormalLoaderType {
142 #[serde(rename = "mistral")]
143 Mistral,
144 #[serde(rename = "gemma")]
145 Gemma,
146 #[serde(rename = "mixtral")]
147 Mixtral,
148 #[serde(rename = "llama")]
149 Llama,
150 #[serde(rename = "phi2")]
151 Phi2,
152 #[serde(rename = "phi3")]
153 Phi3,
154 #[serde(rename = "qwen2")]
155 Qwen2,
156 #[serde(rename = "gemma2")]
157 Gemma2,
158 #[serde(rename = "starcoder2")]
159 Starcoder2,
160 #[serde(rename = "phi3.5moe")]
161 Phi3_5MoE,
162 #[serde(rename = "deepseekv2")]
163 DeepSeekV2,
164 #[serde(rename = "deepseekv3")]
165 DeepSeekV3,
166 #[serde(rename = "qwen3")]
167 Qwen3,
168 #[serde(rename = "qwen3moe")]
169 Qwen3Moe,
170}
171
172impl NormalLoaderType {
174 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
175 match name {
176 "MistralForCausalLM" => Ok(Self::Mistral),
177 "MixtralForCausalLM" => Ok(Self::Mixtral),
178 "GemmaForCausalLM" => Ok(Self::Gemma),
179 "Gemma2ForCausalLM" => Ok(Self::Gemma2),
180 "PhiForCausalLM" => Ok(Self::Phi2),
181 "Phi3ForCausalLM" => Ok(Self::Phi3),
182 "LlamaForCausalLM" => Ok(Self::Llama),
183 "Qwen2ForCausalLM" => Ok(Self::Qwen2),
184 "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
185 "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
186 "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
187 "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
188 "Qwen3ForCausalLM" => Ok(Self::Qwen3),
189 "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
190 other => anyhow::bail!(
191 "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
192 ),
193 }
194 }
195}
196
197impl FromStr for NormalLoaderType {
198 type Err = String;
199 fn from_str(s: &str) -> Result<Self, Self::Err> {
200 match s {
201 "mistral" => Ok(Self::Mistral),
202 "gemma" => Ok(Self::Gemma),
203 "mixtral" => Ok(Self::Mixtral),
204 "llama" => Ok(Self::Llama),
205 "phi2" => Ok(Self::Phi2),
206 "phi3" => Ok(Self::Phi3),
207 "qwen2" => Ok(Self::Qwen2),
208 "gemma2" => Ok(Self::Gemma2),
209 "starcoder2" => Ok(Self::Starcoder2),
210 "phi3.5moe" => Ok(Self::Phi3_5MoE),
211 "deepseekv2" => Ok(Self::DeepSeekV2),
212 "deepseekv3" => Ok(Self::DeepSeekV3),
213 "qwen3" => Ok(Self::DeepSeekV3),
214 "qwen3moe" => Ok(Self::Qwen3Moe),
215 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `qwen3moe`.")),
216 }
217 }
218}
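// A minimal round-trip sketch (illustrative only, not part of the original file):
//
//     let tp: NormalLoaderType = "qwen3".parse().expect("known architecture");
//     assert_eq!(tp, NormalLoaderType::Qwen3);
//     assert_eq!(tp.to_string(), "qwen3");
//
// The `Display` impl below is the inverse of `FromStr`, returning the same snake_case
// architecture name that the parser accepts.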
219
220impl Display for NormalLoaderType {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 match self {
223 Self::Gemma => write!(f, "gemma"),
224 Self::Gemma2 => write!(f, "gemma2"),
225 Self::Llama => write!(f, "llama"),
226 Self::Mistral => write!(f, "mistral"),
227 Self::Mixtral => write!(f, "mixtral"),
228 Self::Phi2 => write!(f, "phi2"),
229 Self::Phi3 => write!(f, "phi3"),
230 Self::Phi3_5MoE => write!(f, "phi3.5moe"),
231 Self::Qwen2 => write!(f, "qwen2"),
232 Self::Starcoder2 => write!(f, "starcoder2"),
233 Self::DeepSeekV2 => write!(f, "deepseekv2"),
234 Self::DeepSeekV3 => write!(f, "deepseekv3"),
235 Self::Qwen3 => write!(f, "qwen3"),
236 Self::Qwen3Moe => write!(f, "qwen3moe"),
237 }
238 }
239}
240
241macro_rules! bias_if {
242 ($cond:expr, $size:expr) => {
243 if $cond {
244 $size
245 } else {
246 0
247 }
248 };
249}
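// Illustrative use of the helper above: `bias_if!(cfg.attention_bias, size_q)` contributes
// `size_q` elements to a layer-size estimate when the config enables attention biases and
// `0` otherwise, so bias terms can be added inline in the per-layer formulas below.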
250
251pub struct AutoNormalLoader;
253
254#[derive(Deserialize)]
255struct AutoNormalLoaderConfig {
256 architectures: Vec<String>,
257}
258
259impl AutoNormalLoader {
260 fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
261 let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
262 if auto_cfg.architectures.len() != 1 {
263 anyhow::bail!("Expected to have one name for `architectures` config field.")
264 }
265
266 let name = &auto_cfg.architectures[0];
267
268 let tp = NormalLoaderType::from_causal_lm_name(name)?;
269
270 once_log_info(format!("Automatic loader type determined to be `{tp}`"));
271
272 match tp {
273 NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
274 NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
275 NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
276 NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
277 NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
278 NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
279 NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
280 NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
281 NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
282 NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
283 NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
284 NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
285 NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
286 NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
287 }
288 }
289}
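// Illustrative sketch (not from the original source): for a `config.json` whose
// `architectures` field is `["MistralForCausalLM"]`, `Self::get_loader` resolves to
// `MistralLoader`; every trait method on `AutoNormalLoader` below then simply delegates
// to that concrete loader with the same `config` string.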
290
291impl NormalModelLoader for AutoNormalLoader {
292 fn load(
293 &self,
294 config: &str,
295 vb: ShardedVarBuilder,
296 normal_loading_metadata: NormalLoadingMetadata,
297 attention_mechanism: AttentionImplementation,
298 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
299 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
300 }
301 fn load_xlora(
302 &self,
303 config: &str,
304 vb: ShardedVarBuilder,
305 lora_config: &[((String, String), LoraConfig)],
306 xlora_config: Option<XLoraConfig>,
307 xlora_ordering: Ordering,
308 normal_loading_metadata: NormalLoadingMetadata,
309 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
310 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
311 Self::get_loader(config)?.load_xlora(
312 config,
313 vb,
314 lora_config,
315 xlora_config,
316 xlora_ordering,
317 normal_loading_metadata,
318 preload_adapters,
319 )
320 }
321 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
322 Self::get_loader(config)?.get_config_repr(config)
323 }
324 fn supports_paged_attention(&self, config: &str) -> Result<bool> {
325 Self::get_loader(config)?.supports_paged_attention(config)
326 }
327 fn is_gptx(&self, config: &str) -> Result<bool> {
328 Self::get_loader(config)?.is_gptx(config)
329 }
330}
331
332impl IsqModelLoader for AutoNormalLoader {
333 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
334 Self::get_loader(config)?.immediate_isq_predicates(config)
335 }
336 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
337 Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
338 }
339 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
340 Self::get_loader(config)?.isq_layer_regexes(config)
341 }
342 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
343 Self::get_loader(config)?.isq_layer_regexes_moqe(config)
344 }
345}
346
347impl DeviceMappedModelLoader for AutoNormalLoader {
348 fn non_mapped_size_in_bytes(
349 &self,
350 config: &str,
351 dtype: DType,
352 weight_pack_factor: usize,
353 ) -> Result<usize> {
354 Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
355 }
356 fn num_layers(&self, config: &str) -> Result<usize> {
357 Self::get_loader(config)?.num_layers(config)
358 }
359 fn layer_sizes_in_bytes(
360 &self,
361 config: &str,
362 dtype: DType,
363 weight_pack_factor: usize,
364 ) -> Result<Vec<usize>> {
365 Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
366 }
367 fn mapped_max_act_size_elems(
368 &self,
369 config: &str,
370 params: &super::AutoDeviceMapParams,
371 prompt_chunksize: usize,
372 ) -> Result<usize> {
373 Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
374 }
375 fn non_mapped_max_act_size_elems(
376 &self,
377 _config: &str,
378 _params: &AutoDeviceMapParams,
379 ) -> Result<usize> {
380 Ok(0)
381 }
382 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
383 Self::get_loader(config)?.model_config(config)
384 }
385}
386
387pub struct MistralLoader;
390
391impl NormalModelLoader for MistralLoader {
392 fn load(
393 &self,
394 config: &str,
395 vb: ShardedVarBuilder,
396 normal_loading_metadata: NormalLoadingMetadata,
397 attention_mechanism: AttentionImplementation,
398 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
399 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
400 Ok(Box::new(models::mistral::Model::new(
401 &cfg,
402 vb,
403 self.is_gptx(config)?,
404 normal_loading_metadata,
405 attention_mechanism,
406 )?))
407 }
408 fn load_xlora(
409 &self,
410 config: &str,
411 vb: ShardedVarBuilder,
412 lora_config: &[((String, String), LoraConfig)],
413 xlora_config: Option<XLoraConfig>,
414 xlora_ordering: Ordering,
415 normal_loading_metadata: NormalLoadingMetadata,
416 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
417 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
418 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
419 Ok(Box::new(xlora_models::XLoraMistral::new(
420 &cfg,
421 vb,
422 lora_config,
423 xlora_config,
424 xlora_ordering,
425 self.is_gptx(config)?,
426 normal_loading_metadata,
427 preload_adapters,
428 )?))
429 }
430 fn is_gptx(&self, _: &str) -> Result<bool> {
431 Ok(true)
432 }
433 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
434 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
435 Ok(Box::new(cfg))
436 }
437}
438
439impl IsqModelLoader for MistralLoader {
440 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
441 Ok(vec![
442 Regex::new(r"lm_head\.(weight|bias)$")?,
443 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
445 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
446 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
447 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
448 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
450 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
451 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
452 ])
453 }
454 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
455 self.isq_layer_regexes(config)
456 }
457}
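// Example of what these ISQ predicates select (hypothetical weight names, assuming the
// usual `model.layers.N....` naming of the checkpoints):
//   `lm_head.weight`                          -> quantized
//   `model.layers.0.self_attn.q_proj.weight`  -> quantized
//   `model.layers.0.input_layernorm.weight`   -> not matched, left in the original dtype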
458
459impl DeviceMappedModelLoader for MistralLoader {
460 fn mapped_max_act_size_elems(
461 &self,
462 config: &str,
463 params: &AutoDeviceMapParams,
464 prompt_chunksize: usize,
465 ) -> Result<usize> {
466 let AutoDeviceMapParams::Text {
467 max_seq_len: _,
468 max_batch_size,
469 } = params
470 else {
471 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
472 };
473
474 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
475
476 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
477 }
478 fn non_mapped_max_act_size_elems(
479 &self,
480 _config: &str,
481 _params: &AutoDeviceMapParams,
482 ) -> Result<usize> {
483 Ok(0)
484 }
485
486 fn non_mapped_size_in_bytes(
487 &self,
488 config: &str,
489 dtype: DType,
490 weight_pack_factor: usize,
491 ) -> Result<usize> {
492 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
493
494 let elems = {
495 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
496 let lm_head = if !cfg.tie_word_embeddings {
497 cfg.hidden_size * cfg.vocab_size
498 } else {
499 0
500 };
501 let norm = cfg.hidden_size;
502 embed_tokens + lm_head + norm
503 };
504 Ok(elems * dtype.size_in_bytes())
505 }
506
507 fn layer_sizes_in_bytes(
508 &self,
509 config: &str,
510 dtype: DType,
511 weight_pack_factor: usize,
512 ) -> Result<Vec<usize>> {
513 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
514
515 let per_layer_elems = {
516 let input_layernorm = cfg.hidden_size;
517 let post_attention_layernorm = cfg.hidden_size;
518
519 let size_in = cfg.hidden_size;
520 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
521 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
522 let q_proj = size_in * size_q / weight_pack_factor;
523 let k_proj = size_in * size_kv / weight_pack_factor;
524 let v_proj = size_in * size_kv / weight_pack_factor;
525 let o_proj = size_q * size_in / weight_pack_factor;
526
527 let h_size = cfg.hidden_size;
528 let i_size = cfg.intermediate_size;
529 let gate_proj = h_size * i_size / weight_pack_factor;
530 let up_proj = h_size * i_size / weight_pack_factor;
531 let down_proj = i_size * h_size / weight_pack_factor;
532
533 input_layernorm
534 + post_attention_layernorm
535 + q_proj
536 + k_proj
537 + v_proj
538 + o_proj
539 + gate_proj
540 + up_proj
541 + down_proj
542 };
543 Ok(vec![
544 per_layer_elems * dtype.size_in_bytes();
545 cfg.num_hidden_layers
546 ])
547 }
548
549 fn num_layers(&self, config: &str) -> Result<usize> {
550 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
551 Ok(cfg.num_hidden_layers)
552 }
553
554 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
555 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
556
557 let cfg = ModelConfigMetadata {
558 max_seq_len: cfg.max_position_embeddings,
559 num_layers: cfg.num_hidden_layers,
560 hidden_size: cfg.hidden_size,
561 num_kv_heads: cfg.num_key_value_heads,
562 num_attn_heads: cfg.num_attention_heads,
563 sliding_window: cfg.sliding_window,
564 k_head_dim: cfg.head_dim(),
565 v_head_dim: cfg.head_dim(),
566 };
567
568 Ok(Box::new(cfg))
569 }
570}
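// Rough worked example for the loader above (illustrative, assuming Mistral-7B-like values:
// hidden_size = 4096, 32 attention heads, 8 KV heads, intermediate_size = 14336, BF16):
// per layer, q/o are 4096*4096, k/v are 4096*1024, and gate/up/down are 4096*14336 each,
// plus 2*4096 for the norms, i.e. about 218M elements or roughly 436 MB per layer with
// `weight_pack_factor == 1`, which is what `layer_sizes_in_bytes` reports.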
571
572pub struct GemmaLoader;
578
579impl NormalModelLoader for GemmaLoader {
580 fn load(
581 &self,
582 config: &str,
583 vb: ShardedVarBuilder,
584 normal_loading_metadata: NormalLoadingMetadata,
585 attention_mechanism: AttentionImplementation,
586 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
587 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
588
589 Ok(Box::new(models::gemma::Model::new(
590 &cfg,
591 vb,
592 self.is_gptx(config)?,
593 normal_loading_metadata,
594 attention_mechanism,
595 )?))
596 }
597 fn load_xlora(
598 &self,
599 config: &str,
600 vb: ShardedVarBuilder,
601 lora_config: &[((String, String), LoraConfig)],
602 xlora_config: Option<XLoraConfig>,
603 xlora_ordering: Ordering,
604 normal_loading_metadata: NormalLoadingMetadata,
605 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
606 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
607 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
608
609 Ok(Box::new(xlora_models::XLoraGemma::new(
610 &cfg,
611 vb,
612 lora_config,
613 xlora_config,
614 xlora_ordering,
615 self.is_gptx(config)?,
616 normal_loading_metadata,
617 preload_adapters,
618 )?))
619 }
620 fn is_gptx(&self, _: &str) -> Result<bool> {
621 Ok(true)
622 }
623 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
624 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
625 Ok(Box::new(cfg))
626 }
627}
628
629impl IsqModelLoader for GemmaLoader {
630 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
631 Ok(vec![
632 Regex::new(r"lm_head\.(weight|bias)$")?,
633 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
635 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
636 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
637 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
638 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
640 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
641 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
642 ])
643 }
644 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
645 self.isq_layer_regexes(config)
646 }
647}
648
649impl DeviceMappedModelLoader for GemmaLoader {
650 fn mapped_max_act_size_elems(
651 &self,
652 config: &str,
653 params: &AutoDeviceMapParams,
654 prompt_chunksize: usize,
655 ) -> Result<usize> {
656 let AutoDeviceMapParams::Text {
657 max_seq_len: _,
658 max_batch_size,
659 } = params
660 else {
661 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
662 };
663
664 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
665
666 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
667 }
668 fn non_mapped_max_act_size_elems(
669 &self,
670 _config: &str,
671 _params: &AutoDeviceMapParams,
672 ) -> Result<usize> {
673 Ok(0)
674 }
675
676 fn non_mapped_size_in_bytes(
677 &self,
678 config: &str,
679 dtype: DType,
680 weight_pack_factor: usize,
681 ) -> Result<usize> {
682 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
683
684 let elems = {
685 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
686 let lm_head = if !cfg.tie_word_embeddings {
687 cfg.hidden_size * cfg.vocab_size
688 } else {
689 0
690 };
691 let norm = cfg.hidden_size;
692 embed_tokens + lm_head + norm
693 };
694 Ok(elems * dtype.size_in_bytes())
695 }
696
697 fn layer_sizes_in_bytes(
698 &self,
699 config: &str,
700 dtype: DType,
701 weight_pack_factor: usize,
702 ) -> Result<Vec<usize>> {
703 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
704
705 let per_layer_elems = {
706 let input_layernorm = cfg.hidden_size;
707 let post_attention_layernorm = cfg.hidden_size;
708
709 let size_in = cfg.hidden_size;
710 let size_q = cfg.head_dim * cfg.num_attention_heads;
711 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
712 let q_proj =
713 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
714 let k_proj =
715 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
716 let v_proj =
717 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
718 let o_proj =
719 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
720
721 let h_size = cfg.hidden_size;
722 let i_size = cfg.intermediate_size;
723 let gate_proj = h_size * i_size / weight_pack_factor;
724 let up_proj = h_size * i_size / weight_pack_factor;
725 let down_proj = i_size * h_size / weight_pack_factor;
726
727 input_layernorm
728 + post_attention_layernorm
729 + q_proj
730 + k_proj
731 + v_proj
732 + o_proj
733 + gate_proj
734 + up_proj
735 + down_proj
736 };
737 Ok(vec![
738 per_layer_elems * dtype.size_in_bytes();
739 cfg.num_hidden_layers
740 ])
741 }
742
743 fn num_layers(&self, config: &str) -> Result<usize> {
744 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
745 Ok(cfg.num_hidden_layers)
746 }
747
748 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
749 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
750
751 let cfg = ModelConfigMetadata {
752 max_seq_len: cfg.max_position_embeddings,
753 num_layers: cfg.num_hidden_layers,
754 hidden_size: cfg.hidden_size,
755 num_kv_heads: cfg.num_key_value_heads,
756 num_attn_heads: cfg.num_attention_heads,
757 sliding_window: None,
758 k_head_dim: cfg.head_dim,
759 v_head_dim: cfg.head_dim,
760 };
761
762 Ok(Box::new(cfg))
763 }
764}
765
766pub struct LlamaLoader;
772
773impl NormalModelLoader for LlamaLoader {
774 fn load(
775 &self,
776 config: &str,
777 vb: ShardedVarBuilder,
778 normal_loading_metadata: NormalLoadingMetadata,
779 attention_mechanism: AttentionImplementation,
780 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
781 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
782
783 Ok(Box::new(models::llama::Llama::new(
784 &cfg,
785 vb,
786 self.is_gptx(config)?,
787 normal_loading_metadata,
788 attention_mechanism,
789 )?))
790 }
791 fn load_xlora(
792 &self,
793 config: &str,
794 vb: ShardedVarBuilder,
795 lora_config: &[((String, String), LoraConfig)],
796 xlora_config: Option<XLoraConfig>,
797 xlora_ordering: Ordering,
798 normal_loading_metadata: NormalLoadingMetadata,
799 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
800 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
801 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
802
803 Ok(Box::new(xlora_models::XLoraLlama::new(
804 &cfg,
805 vb,
806 lora_config,
807 xlora_config,
808 xlora_ordering,
809 self.is_gptx(config)?,
810 normal_loading_metadata,
811 preload_adapters,
812 )?))
813 }
814 fn is_gptx(&self, _: &str) -> Result<bool> {
815 Ok(true)
816 }
817 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
818 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
819 Ok(Box::new(cfg))
820 }
821}
822
823impl IsqModelLoader for LlamaLoader {
824 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
825 Ok(vec![
826 Regex::new(r"lm_head\.(weight|bias)$")?,
827 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
829 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
830 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
831 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
832 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
834 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
835 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
836 ])
837 }
838 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
839 self.isq_layer_regexes(config)
840 }
841}
842
843impl DeviceMappedModelLoader for LlamaLoader {
844 fn mapped_max_act_size_elems(
845 &self,
846 config: &str,
847 params: &AutoDeviceMapParams,
848 prompt_chunksize: usize,
849 ) -> Result<usize> {
850 let AutoDeviceMapParams::Text {
851 max_seq_len: _,
852 max_batch_size,
853 } = params
854 else {
855 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
856 };
857
858 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
859
860 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
861 }
862 fn non_mapped_max_act_size_elems(
863 &self,
864 _config: &str,
865 _params: &AutoDeviceMapParams,
866 ) -> Result<usize> {
867 Ok(0)
868 }
869
870 fn non_mapped_size_in_bytes(
871 &self,
872 config: &str,
873 dtype: DType,
874 weight_pack_factor: usize,
875 ) -> Result<usize> {
876 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
877
878 let elems = {
879 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
880 let lm_head = if !cfg.tie_word_embeddings {
881 cfg.hidden_size * cfg.vocab_size
882 } else {
883 0
884 };
885 let norm = cfg.hidden_size;
886 embed_tokens + lm_head + norm
887 };
888 Ok(elems * dtype.size_in_bytes())
889 }
890
891 fn layer_sizes_in_bytes(
892 &self,
893 config: &str,
894 dtype: DType,
895 weight_pack_factor: usize,
896 ) -> Result<Vec<usize>> {
897 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
898
899 let per_layer_elems = {
900 let input_layernorm = cfg.hidden_size;
901 let post_attention_layernorm = cfg.hidden_size;
902
903 let size_in = cfg.hidden_size;
904 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
905 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
906 let q_proj = size_in * size_q / weight_pack_factor;
907 let k_proj = size_in * size_kv / weight_pack_factor;
908 let v_proj = size_in * size_kv / weight_pack_factor;
909 let o_proj = size_q * size_in / weight_pack_factor;
910
911 let h_size = cfg.hidden_size;
912 let i_size = cfg.intermediate_size;
913 let gate_proj = h_size * i_size / weight_pack_factor;
914 let up_proj = h_size * i_size / weight_pack_factor;
915 let down_proj = i_size * h_size / weight_pack_factor;
916
917 input_layernorm
918 + post_attention_layernorm
919 + q_proj
920 + k_proj
921 + v_proj
922 + o_proj
923 + gate_proj
924 + up_proj
925 + down_proj
926 };
927 Ok(vec![
928 per_layer_elems * dtype.size_in_bytes();
929 cfg.num_hidden_layers
930 ])
931 }
932
933 fn num_layers(&self, config: &str) -> Result<usize> {
934 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
935
936 Ok(cfg.num_hidden_layers)
937 }
938 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
939 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
940
941 let cfg = ModelConfigMetadata {
942 max_seq_len: cfg.max_position_embeddings,
943 num_layers: cfg.num_hidden_layers,
944 hidden_size: cfg.hidden_size,
945 num_kv_heads: cfg.num_key_value_heads,
946 num_attn_heads: cfg.num_attention_heads,
947 sliding_window: None,
948 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
949 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
950 };
951
952 Ok(Box::new(cfg))
953 }
954}
955
956pub struct MixtralLoader;
959
960impl NormalModelLoader for MixtralLoader {
961 fn load(
962 &self,
963 config: &str,
964 vb: ShardedVarBuilder,
965 normal_loading_metadata: NormalLoadingMetadata,
966 attention_mechanism: AttentionImplementation,
967 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
968 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
969
970 Ok(Box::new(models::mixtral::Model::new(
971 &cfg,
972 vb,
973 self.is_gptx(config)?,
974 normal_loading_metadata,
975 attention_mechanism,
976 )?))
977 }
978 fn load_xlora(
979 &self,
980 config: &str,
981 vb: ShardedVarBuilder,
982 lora_config: &[((String, String), LoraConfig)],
983 xlora_config: Option<XLoraConfig>,
984 xlora_ordering: Ordering,
985 normal_loading_metadata: NormalLoadingMetadata,
986 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
987 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
988 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
989
990 Ok(Box::new(xlora_models::XLoraMixtral::new(
991 &cfg,
992 vb,
993 lora_config,
994 xlora_config,
995 xlora_ordering,
996 self.is_gptx(config)?,
997 normal_loading_metadata,
998 preload_adapters,
999 )?))
1000 }
1001 fn is_gptx(&self, _: &str) -> Result<bool> {
1002 Ok(true)
1003 }
1004 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1005 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1006
1007 Ok(Box::new(cfg))
1008 }
1009}
1010
1011impl IsqModelLoader for MixtralLoader {
1012 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1013 Ok(vec![
1014 Regex::new(r"lm_head\.(weight|bias)$")?,
1015 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1017 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1018 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1019 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1020 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1022 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1023 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1024 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1025 ])
1026 }
1027 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1028 self.isq_layer_regexes(config)
1029 }
1030}
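// Note (illustrative): unlike the dense loaders above, Mixtral's MLP predicates target the
// per-expert weights, so a hypothetical name like
// `model.layers.3.block_sparse_moe.experts.7.w1.weight` is matched by the
// `experts\.(\d+)\.w1` pattern and every expert is quantized individually.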
1031
1032impl DeviceMappedModelLoader for MixtralLoader {
1033 fn mapped_max_act_size_elems(
1034 &self,
1035 config: &str,
1036 params: &AutoDeviceMapParams,
1037 prompt_chunksize: usize,
1038 ) -> Result<usize> {
1039 let AutoDeviceMapParams::Text {
1040 max_seq_len: _,
1041 max_batch_size,
1042 } = params
1043 else {
1044 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1045 };
1046
1047 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1048
1049 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1050 }
1051 fn non_mapped_max_act_size_elems(
1052 &self,
1053 _config: &str,
1054 _params: &AutoDeviceMapParams,
1055 ) -> Result<usize> {
1056 Ok(0)
1057 }
1058
1059 fn non_mapped_size_in_bytes(
1060 &self,
1061 config: &str,
1062 dtype: DType,
1063 weight_pack_factor: usize,
1064 ) -> Result<usize> {
1065 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1066
1067 let elems = {
1068 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1069 let lm_head = if !cfg.tie_word_embeddings {
1070 cfg.hidden_size * cfg.vocab_size
1071 } else {
1072 0
1073 };
1074 let norm = cfg.hidden_size;
1075 embed_tokens + lm_head + norm
1076 };
1077 Ok(elems * dtype.size_in_bytes())
1078 }
1079
1080 fn layer_sizes_in_bytes(
1081 &self,
1082 config: &str,
1083 dtype: DType,
1084 weight_pack_factor: usize,
1085 ) -> Result<Vec<usize>> {
1086 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1087
1088 let per_layer_elems = {
1089 let input_layernorm = cfg.hidden_size;
1090 let post_attention_layernorm = cfg.hidden_size;
1091
1092 let size_in = cfg.hidden_size;
1093 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1094 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1095 let q_proj = size_in * size_q / weight_pack_factor;
1096 let k_proj = size_in * size_kv / weight_pack_factor;
1097 let v_proj = size_in * size_kv / weight_pack_factor;
1098 let o_proj = size_q * size_in / weight_pack_factor;
1099
1100 let moe_block = {
1101 let gate = cfg.hidden_size * cfg.num_local_experts;
1102 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1104 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1105 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1106 gate + cfg.num_local_experts * w1
1107 + cfg.num_local_experts * w2
1108 + cfg.num_local_experts * w3
1109 };
1110
1111 input_layernorm
1112 + post_attention_layernorm
1113 + q_proj
1114 + k_proj
1115 + v_proj
1116 + o_proj
1117 + moe_block
1118 };
1119 Ok(vec![
1120 per_layer_elems * dtype.size_in_bytes();
1121 cfg.num_hidden_layers
1122 ])
1123 }
1124
1125 fn num_layers(&self, config: &str) -> Result<usize> {
1126 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1127
1128 Ok(cfg.num_hidden_layers)
1129 }
1130
1131 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1132 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1133
1134 let cfg = ModelConfigMetadata {
1135 max_seq_len: cfg.max_position_embeddings,
1136 num_layers: cfg.num_hidden_layers,
1137 hidden_size: cfg.hidden_size,
1138 num_kv_heads: cfg.num_key_value_heads,
1139 num_attn_heads: cfg.num_attention_heads,
1140 sliding_window: cfg.sliding_window,
1141 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1142 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1143 };
1144
1145 Ok(Box::new(cfg))
1146 }
1147}
1148
1149pub struct Phi2Loader;
1155
1156impl NormalModelLoader for Phi2Loader {
1157 fn load(
1158 &self,
1159 config: &str,
1160 vb: ShardedVarBuilder,
1161 normal_loading_metadata: NormalLoadingMetadata,
1162 attention_mechanism: AttentionImplementation,
1163 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1164 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1165
1166 Ok(Box::new(models::phi2::Model::new(
1167 &cfg,
1168 vb,
1169 self.is_gptx(config)?,
1170 normal_loading_metadata,
1171 attention_mechanism,
1172 )?))
1173 }
1174 fn load_xlora(
1175 &self,
1176 config: &str,
1177 vb: ShardedVarBuilder,
1178 lora_config: &[((String, String), LoraConfig)],
1179 xlora_config: Option<XLoraConfig>,
1180 xlora_ordering: Ordering,
1181 normal_loading_metadata: NormalLoadingMetadata,
1182 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1183 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1184 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1185
1186 Ok(Box::new(xlora_models::XLoraPhi2::new(
1187 &cfg,
1188 vb,
1189 lora_config,
1190 xlora_config,
1191 xlora_ordering,
1192 self.is_gptx(config)?,
1193 normal_loading_metadata,
1194 preload_adapters,
1195 )?))
1196 }
1197 fn is_gptx(&self, _: &str) -> Result<bool> {
1198 Ok(true)
1199 }
1200 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1201 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1202
1203 Ok(Box::new(cfg))
1204 }
1205}
1206
1207impl IsqModelLoader for Phi2Loader {
1208 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1209 Ok(vec![
1210 Regex::new(r"lm_head\.(weight|bias)$")?,
1211 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1213 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1214 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1215 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1216 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1218 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1219 ])
1220 }
1221 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1222 self.isq_layer_regexes(config)
1223 }
1224}
1225
1226impl DeviceMappedModelLoader for Phi2Loader {
1227 fn mapped_max_act_size_elems(
1228 &self,
1229 config: &str,
1230 params: &AutoDeviceMapParams,
1231 prompt_chunksize: usize,
1232 ) -> Result<usize> {
1233 let AutoDeviceMapParams::Text {
1234 max_seq_len: _,
1235 max_batch_size,
1236 } = params
1237 else {
1238 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1239 };
1240
1241 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1242
1243 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1244 }
1245 fn non_mapped_max_act_size_elems(
1246 &self,
1247 _config: &str,
1248 _params: &AutoDeviceMapParams,
1249 ) -> Result<usize> {
1250 Ok(0)
1251 }
1252
1253 fn non_mapped_size_in_bytes(
1254 &self,
1255 config: &str,
1256 dtype: DType,
1257 weight_pack_factor: usize,
1258 ) -> Result<usize> {
1259 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1260
1261 let elems = {
1262 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1263 let lm_head = if !cfg.tie_word_embeddings {
1264 cfg.hidden_size * cfg.vocab_size
1265 } else {
1266 0
1267 };
1268 let norm = cfg.hidden_size;
1269 embed_tokens + lm_head + norm
1270 };
1271 Ok(elems * dtype.size_in_bytes())
1272 }
1273
1274 fn layer_sizes_in_bytes(
1275 &self,
1276 config: &str,
1277 dtype: DType,
1278 weight_pack_factor: usize,
1279 ) -> Result<Vec<usize>> {
1280 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1281
1282 let per_layer_elems = {
1283 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1284
1285 let size_in = cfg.hidden_size;
1286 let size_q = cfg.head_dim() * cfg.num_attention_heads;
1287 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1288 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1289 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1290 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1291 let o_proj = size_q * size_in / weight_pack_factor + size_in;
1292 let (q_norm, k_norm) = if cfg.qk_layernorm {
1293 (cfg.head_dim(), cfg.head_dim())
1294 } else {
1295 (0, 0)
1296 };
1297
1298 let h_size = cfg.hidden_size;
1299 let i_size = cfg.intermediate_size;
1300 let fc1 = h_size * i_size / weight_pack_factor;
1301 let fc2 = h_size * i_size / weight_pack_factor;
1302
1303 input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1304 };
1305 Ok(vec![
1306 per_layer_elems * dtype.size_in_bytes();
1307 cfg.num_hidden_layers
1308 ])
1309 }
1310
1311 fn num_layers(&self, config: &str) -> Result<usize> {
1312 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1313
1314 Ok(cfg.num_hidden_layers)
1315 }
1316
1317 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1318 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1319
1320 let cfg = ModelConfigMetadata {
1321 max_seq_len: cfg.max_position_embeddings,
1322 num_layers: cfg.num_hidden_layers,
1323 hidden_size: cfg.hidden_size,
1324 num_kv_heads: cfg.num_key_value_heads(),
1325 num_attn_heads: cfg.num_attention_heads,
1326 sliding_window: None,
1327 k_head_dim: cfg.head_dim(),
1328 v_head_dim: cfg.head_dim(),
1329 };
1330
1331 Ok(Box::new(cfg))
1332 }
1333}
1334
1335pub struct Phi3Loader;
1341
1342impl NormalModelLoader for Phi3Loader {
1343 fn load(
1344 &self,
1345 config: &str,
1346 vb: ShardedVarBuilder,
1347 normal_loading_metadata: NormalLoadingMetadata,
1348 attention_mechanism: AttentionImplementation,
1349 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1350 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1351
1352 Ok(Box::new(models::phi3::Model::new(
1353 &cfg,
1354 vb,
1355 self.is_gptx(config)?,
1356 normal_loading_metadata,
1357 attention_mechanism,
1358 )?))
1359 }
1360 fn load_xlora(
1361 &self,
1362 config: &str,
1363 vb: ShardedVarBuilder,
1364 lora_config: &[((String, String), LoraConfig)],
1365 xlora_config: Option<XLoraConfig>,
1366 xlora_ordering: Ordering,
1367 normal_loading_metadata: NormalLoadingMetadata,
1368 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1369 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1370 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1371
1372 Ok(Box::new(xlora_models::XLoraPhi3::new(
1373 &cfg,
1374 vb,
1375 lora_config,
1376 xlora_config,
1377 xlora_ordering,
1378 self.is_gptx(config)?,
1379 normal_loading_metadata,
1380 preload_adapters,
1381 )?))
1382 }
1383 fn is_gptx(&self, _: &str) -> Result<bool> {
1384 Ok(true)
1385 }
1386 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1387 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1388
1389 Ok(Box::new(cfg))
1390 }
1391}
1392
1393impl IsqModelLoader for Phi3Loader {
1394 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1395 Ok(vec![
1396 Regex::new(r"lm_head\.(weight|bias)$")?,
1397 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1399 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1400 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1402 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1403 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1404 ])
1405 }
1406 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1407 self.isq_layer_regexes(config)
1408 }
1409}
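// Note (illustrative): Phi-3 fuses the attention projections into a single `qkv_proj`
// tensor, which is why `layer_sizes_in_bytes` below sizes it as
// `num_attention_heads * head_dim + 2 * num_key_value_heads * head_dim` output rows
// rather than as separate q/k/v projections.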
1410
1411impl DeviceMappedModelLoader for Phi3Loader {
1412 fn mapped_max_act_size_elems(
1413 &self,
1414 config: &str,
1415 params: &AutoDeviceMapParams,
1416 prompt_chunksize: usize,
1417 ) -> Result<usize> {
1418 let AutoDeviceMapParams::Text {
1419 max_seq_len: _,
1420 max_batch_size,
1421 } = params
1422 else {
1423 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1424 };
1425
1426 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1427
1428 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1429 }
1430 fn non_mapped_max_act_size_elems(
1431 &self,
1432 _config: &str,
1433 _params: &AutoDeviceMapParams,
1434 ) -> Result<usize> {
1435 Ok(0)
1436 }
1437
1438 fn non_mapped_size_in_bytes(
1439 &self,
1440 config: &str,
1441 dtype: DType,
1442 weight_pack_factor: usize,
1443 ) -> Result<usize> {
1444 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1445
1446 let elems = {
1447 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1448 let lm_head = if !cfg.tie_word_embeddings {
1449 cfg.hidden_size * cfg.vocab_size
1450 } else {
1451 0
1452 };
1453 let norm = cfg.hidden_size;
1454 embed_tokens + lm_head + norm
1455 };
1456 Ok(elems * dtype.size_in_bytes())
1457 }
1458
1459 fn layer_sizes_in_bytes(
1460 &self,
1461 config: &str,
1462 dtype: DType,
1463 weight_pack_factor: usize,
1464 ) -> Result<Vec<usize>> {
1465 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1466
1467 let per_layer_elems = {
1468 let input_layernorm = cfg.hidden_size;
1469 let post_attention_layernorm = cfg.hidden_size;
1470
1471 let size_in = cfg.hidden_size;
1472 let head_dim = cfg.head_dim();
1473 let op_size =
1474 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1475 let qkv_proj = size_in * op_size / weight_pack_factor;
1476 let o_proj =
1477 (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1478
1479 let h_size = cfg.hidden_size;
1480 let i_size = cfg.intermediate_size;
1481 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1482 let down_proj = h_size * i_size / weight_pack_factor;
1483
1484 input_layernorm
1485 + post_attention_layernorm
1486 + qkv_proj
1487 + o_proj
1488 + gate_up_proj
1489 + down_proj
1490 };
1491 Ok(vec![
1492 per_layer_elems * dtype.size_in_bytes();
1493 cfg.num_hidden_layers
1494 ])
1495 }
1496
1497 fn num_layers(&self, config: &str) -> Result<usize> {
1498 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1499
1500 Ok(cfg.num_hidden_layers)
1501 }
1502
1503 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1504 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1505
1506 let cfg = ModelConfigMetadata {
1507 max_seq_len: cfg.max_position_embeddings,
1508 num_layers: cfg.num_hidden_layers,
1509 hidden_size: cfg.hidden_size,
1510 num_kv_heads: cfg.num_key_value_heads,
1511 num_attn_heads: cfg.num_attention_heads,
1512 sliding_window: cfg.sliding_window,
1513 k_head_dim: cfg.head_dim(),
1514 v_head_dim: cfg.head_dim(),
1515 };
1516
1517 Ok(Box::new(cfg))
1518 }
1519}
1520
1521pub struct Qwen2Loader;
1527
1528impl NormalModelLoader for Qwen2Loader {
1529 fn load(
1530 &self,
1531 config: &str,
1532 vb: ShardedVarBuilder,
1533 normal_loading_metadata: NormalLoadingMetadata,
1534 attention_mechanism: AttentionImplementation,
1535 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1536 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1537
1538 Ok(Box::new(models::qwen2::Model::new(
1539 &cfg,
1540 vb,
1541 self.is_gptx(config)?,
1542 normal_loading_metadata,
1543 attention_mechanism,
1544 )?))
1545 }
1546 fn load_xlora(
1547 &self,
1548 _config: &str,
1549 _vb: ShardedVarBuilder,
1550 _lora_config: &[((String, String), LoraConfig)],
1551 _xlora_config: Option<XLoraConfig>,
1552 _xlora_ordering: Ordering,
1553 _normal_loading_metadata: NormalLoadingMetadata,
1554 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1555 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1556 todo!()
1557 }
1558 fn is_gptx(&self, _: &str) -> Result<bool> {
1559 Ok(true)
1560 }
1561 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1562 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1563
1564 Ok(Box::new(cfg))
1565 }
1566}
1567
1568impl IsqModelLoader for Qwen2Loader {
1569 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1570 Ok(vec![
1571 Regex::new(r"lm_head\.(weight|bias)$")?,
1572 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1574 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1575 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1576 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1577 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1579 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1580 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1581 ])
1582 }
1583 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1584 self.isq_layer_regexes(config)
1585 }
1586}
1587
1588impl DeviceMappedModelLoader for Qwen2Loader {
1589 fn mapped_max_act_size_elems(
1590 &self,
1591 config: &str,
1592 params: &AutoDeviceMapParams,
1593 prompt_chunksize: usize,
1594 ) -> Result<usize> {
1595 let AutoDeviceMapParams::Text {
1596 max_seq_len: _,
1597 max_batch_size,
1598 } = params
1599 else {
1600 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1601 };
1602
1603 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1604
1605 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1606 }
1607 fn non_mapped_max_act_size_elems(
1608 &self,
1609 _config: &str,
1610 _params: &AutoDeviceMapParams,
1611 ) -> Result<usize> {
1612 Ok(0)
1613 }
1614
1615 fn non_mapped_size_in_bytes(
1616 &self,
1617 config: &str,
1618 dtype: DType,
1619 weight_pack_factor: usize,
1620 ) -> Result<usize> {
1621 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1622
1623 let elems = {
1624 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1625 let lm_head = if !cfg.tie_word_embeddings {
1626 cfg.hidden_size * cfg.vocab_size
1627 } else {
1628 0
1629 };
1630 let norm = cfg.hidden_size;
1631 embed_tokens + lm_head + norm
1632 };
1633 Ok(elems * dtype.size_in_bytes())
1634 }
1635
1636 fn layer_sizes_in_bytes(
1637 &self,
1638 config: &str,
1639 dtype: DType,
1640 weight_pack_factor: usize,
1641 ) -> Result<Vec<usize>> {
1642 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1643
1644 let per_layer_elems = {
1645 let input_layernorm = cfg.hidden_size;
1646 let post_attention_layernorm = cfg.hidden_size;
1647
1648 let size_in = cfg.hidden_size;
1649 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1650 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1651 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1652 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1653 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1654 let o_proj = size_q * size_in / weight_pack_factor;
1655
1656 let h_size = cfg.hidden_size;
1657 let i_size = cfg.intermediate_size;
1658 let gate_proj = h_size * i_size / weight_pack_factor;
1659 let up_proj = h_size * i_size / weight_pack_factor;
1660 let down_proj = i_size * h_size / weight_pack_factor;
1661
1662 input_layernorm
1663 + post_attention_layernorm
1664 + q_proj
1665 + k_proj
1666 + v_proj
1667 + o_proj
1668 + gate_proj
1669 + up_proj
1670 + down_proj
1671 };
1672 Ok(vec![
1673 per_layer_elems * dtype.size_in_bytes();
1674 cfg.num_hidden_layers
1675 ])
1676 }
1677
1678 fn num_layers(&self, config: &str) -> Result<usize> {
1679 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1680
1681 Ok(cfg.num_hidden_layers)
1682 }
1683
1684 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1685 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1686
1687 let cfg = ModelConfigMetadata {
1688 max_seq_len: cfg.max_position_embeddings,
1689 num_layers: cfg.num_hidden_layers,
1690 hidden_size: cfg.hidden_size,
1691 num_kv_heads: cfg.num_key_value_heads,
1692 num_attn_heads: cfg.num_attention_heads,
1693 sliding_window: cfg.sliding_window,
1694 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1695 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1696 };
1697
1698 Ok(Box::new(cfg))
1699 }
1700}
1701
1702pub struct Gemma2Loader;
1708
1709impl NormalModelLoader for Gemma2Loader {
1710 fn load(
1711 &self,
1712 config: &str,
1713 vb: ShardedVarBuilder,
1714 normal_loading_metadata: NormalLoadingMetadata,
1715 attention_mechanism: AttentionImplementation,
1716 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1717 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1718
1719 Ok(Box::new(models::gemma2::Model::new(
1720 &cfg,
1721 vb,
1722 self.is_gptx(config)?,
1723 normal_loading_metadata,
1724 attention_mechanism,
1725 )?))
1726 }
1727 fn load_xlora(
1728 &self,
1729 config: &str,
1730 vb: ShardedVarBuilder,
1731 lora_config: &[((String, String), LoraConfig)],
1732 xlora_config: Option<XLoraConfig>,
1733 xlora_ordering: Ordering,
1734 normal_loading_metadata: NormalLoadingMetadata,
1735 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1736 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1737 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1738
1739 Ok(Box::new(xlora_models::XLoraGemma2::new(
1740 &cfg,
1741 vb,
1742 lora_config,
1743 xlora_config,
1744 xlora_ordering,
1745 self.is_gptx(config)?,
1746 normal_loading_metadata,
1747 preload_adapters,
1748 )?))
1749 }
1750 fn is_gptx(&self, _: &str) -> Result<bool> {
1751 Ok(true)
1752 }
1753 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1754 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1755
1756 Ok(Box::new(cfg))
1757 }
1758}
1759
1760impl IsqModelLoader for Gemma2Loader {
1761 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1762 Ok(vec![
1763 Regex::new(r"lm_head\.(weight|bias)$")?,
1764 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1766 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1767 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1768 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1769 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1771 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1772 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1773 ])
1774 }
1775 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1776 self.isq_layer_regexes(config)
1777 }
1778}
1779
1780impl DeviceMappedModelLoader for Gemma2Loader {
1781 fn mapped_max_act_size_elems(
1782 &self,
1783 config: &str,
1784 params: &AutoDeviceMapParams,
1785 prompt_chunksize: usize,
1786 ) -> Result<usize> {
1787 let AutoDeviceMapParams::Text {
1788 max_seq_len: _,
1789 max_batch_size,
1790 } = params
1791 else {
1792 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1793 };
1794
1795 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1796
1797 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1798 }
1799 fn non_mapped_max_act_size_elems(
1800 &self,
1801 _config: &str,
1802 _params: &AutoDeviceMapParams,
1803 ) -> Result<usize> {
1804 Ok(0)
1805 }
1806
1807 fn non_mapped_size_in_bytes(
1808 &self,
1809 config: &str,
1810 dtype: DType,
1811 weight_pack_factor: usize,
1812 ) -> Result<usize> {
1813 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1814
1815 let elems = {
1816 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1817 let lm_head = if !cfg.tie_word_embeddings {
1818 cfg.hidden_size * cfg.vocab_size
1819 } else {
1820 0
1821 };
1822 let norm = cfg.hidden_size;
1823 embed_tokens + lm_head + norm
1824 };
1825 Ok(elems * dtype.size_in_bytes())
1826 }
1827
1828 fn layer_sizes_in_bytes(
1829 &self,
1830 config: &str,
1831 dtype: DType,
1832 weight_pack_factor: usize,
1833 ) -> Result<Vec<usize>> {
1834 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1835
1836 let per_layer_elems = {
1837 let input_layernorm = cfg.hidden_size;
1838 let post_attention_layernorm = cfg.hidden_size;
1839
1840 let size_in = cfg.hidden_size;
1841 let size_q = cfg.head_dim * cfg.num_attention_heads;
1842 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1843 let q_proj =
1844 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1845 let k_proj =
1846 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1847 let v_proj =
1848 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1849 let o_proj =
1850 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1851
1852 let h_size = cfg.hidden_size;
1853 let i_size = cfg.intermediate_size;
1854 let gate_proj = h_size * i_size / weight_pack_factor;
1855 let up_proj = h_size * i_size / weight_pack_factor;
1856 let down_proj = i_size * h_size / weight_pack_factor;
1857
1858 input_layernorm
1859 + post_attention_layernorm
1860 + q_proj
1861 + k_proj
1862 + v_proj
1863 + o_proj
1864 + gate_proj
1865 + up_proj
1866 + down_proj
1867 };
1868 Ok(vec![
1869 per_layer_elems * dtype.size_in_bytes();
1870 cfg.num_hidden_layers
1871 ])
1872 }
1873
1874 fn num_layers(&self, config: &str) -> Result<usize> {
1875 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1876
1877 Ok(cfg.num_hidden_layers)
1878 }
1879 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1880 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1881
1882 let cfg = ModelConfigMetadata {
1883 max_seq_len: cfg.max_position_embeddings,
1884 num_layers: cfg.num_hidden_layers,
1885 hidden_size: cfg.hidden_size,
1886 num_kv_heads: cfg.num_key_value_heads,
1887 num_attn_heads: cfg.num_attention_heads,
1888 sliding_window: None,
1889 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1890 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1891 };
1892
1893 Ok(Box::new(cfg))
1894 }
1895}
1896
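/// [`NormalModelLoader`] for Starcoder2: deserializes `models::starcoder2::Config` from the
/// JSON config string and builds either the base model or the X-LoRA variant below.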
1897pub struct Starcoder2Loader;
1903
1904impl NormalModelLoader for Starcoder2Loader {
1905 fn load(
1906 &self,
1907 config: &str,
1908 vb: ShardedVarBuilder,
1909 normal_loading_metadata: NormalLoadingMetadata,
1910 attention_mechanism: AttentionImplementation,
1911 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1912 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1913
1914 Ok(Box::new(models::starcoder2::Model::new(
1915 &cfg,
1916 vb,
1917 self.is_gptx(config)?,
1918 normal_loading_metadata,
1919 attention_mechanism,
1920 )?))
1921 }
1922 fn load_xlora(
1923 &self,
1924 config: &str,
1925 vb: ShardedVarBuilder,
1926 lora_config: &[((String, String), LoraConfig)],
1927 xlora_config: Option<XLoraConfig>,
1928 xlora_ordering: Ordering,
1929 normal_loading_metadata: NormalLoadingMetadata,
1930 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1931 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1932 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1933
1934 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
1935 &cfg,
1936 vb,
1937 lora_config,
1938 xlora_config,
1939 xlora_ordering,
1940 self.is_gptx(config)?,
1941 normal_loading_metadata,
1942 preload_adapters,
1943 )?))
1944 }
1945 fn is_gptx(&self, _: &str) -> Result<bool> {
1946 Ok(true)
1947 }
1948 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1949 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1950
1951 Ok(Box::new(cfg))
1952 }
1953}
1954
1955impl IsqModelLoader for Starcoder2Loader {
1956 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1957 Ok(vec![
1958 Regex::new(r"lm_head\.(weight|bias)$")?,
1959 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1961 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1962 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1963 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1964 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1966 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
1967 ])
1968 }
1969 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1970 self.isq_layer_regexes(config)
1971 }
1972}
1973
1974impl DeviceMappedModelLoader for Starcoder2Loader {
1975 fn mapped_max_act_size_elems(
1976 &self,
1977 config: &str,
1978 params: &AutoDeviceMapParams,
1979 prompt_chunksize: usize,
1980 ) -> Result<usize> {
1981 let AutoDeviceMapParams::Text {
1982 max_seq_len: _,
1983 max_batch_size,
1984 } = params
1985 else {
1986 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1987 };
1988
1989 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1990
1991 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1992 }
1993 fn non_mapped_max_act_size_elems(
1994 &self,
1995 _config: &str,
1996 _params: &AutoDeviceMapParams,
1997 ) -> Result<usize> {
1998 Ok(0)
1999 }
2000
2001 fn non_mapped_size_in_bytes(
2002 &self,
2003 config: &str,
2004 dtype: DType,
2005 weight_pack_factor: usize,
2006 ) -> Result<usize> {
2007 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2008
2009 let elems = {
2010 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2011 let lm_head = if !cfg.tie_word_embeddings {
2012 cfg.hidden_size * cfg.vocab_size
2013 } else {
2014 0
2015 };
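// Starcoder2 uses LayerNorm with both weight and bias, hence 2 * hidden_size per norm.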
2016 let norm = cfg.hidden_size + cfg.hidden_size;
2017 embed_tokens + lm_head + norm
2018 };
2019 Ok(elems * dtype.size_in_bytes())
2020 }
2021
2022 fn layer_sizes_in_bytes(
2023 &self,
2024 config: &str,
2025 dtype: DType,
2026 weight_pack_factor: usize,
2027 ) -> Result<Vec<usize>> {
2028 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2029
2030 let per_layer_elems = {
2031 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2032 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2033
2034 let size_in = cfg.hidden_size;
2035 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2036 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2037 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2038 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2039 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2040 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2041
2042 let h_size = cfg.hidden_size;
2043 let i_size = cfg.intermediate_size;
2044 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2045 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2046
2047 input_layernorm
2048 + post_attention_layernorm
2049 + q_proj
2050 + k_proj
2051 + v_proj
2052 + o_proj
2053 + fc1
2054 + fc2
2055 };
2056 Ok(vec![
2057 per_layer_elems * dtype.size_in_bytes();
2058 cfg.num_hidden_layers
2059 ])
2060 }
2061
2062 fn num_layers(&self, config: &str) -> Result<usize> {
2063 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2064
2065 Ok(cfg.num_hidden_layers)
2066 }
2067
2068 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2069 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2070
2071 let cfg = ModelConfigMetadata {
2072 max_seq_len: cfg.max_position_embeddings,
2073 num_layers: cfg.num_hidden_layers,
2074 hidden_size: cfg.hidden_size,
2075 num_kv_heads: cfg.num_key_value_heads,
2076 num_attn_heads: cfg.num_attention_heads,
2077 sliding_window: cfg.sliding_window,
2078 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2079 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2080 };
2081
2082 Ok(Box::new(cfg))
2083 }
2084}
2085
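/// [`NormalModelLoader`] for Phi-3.5-MoE: the base path builds `models::phi3_5_moe::Model`,
/// while the X-LoRA path below constructs `xlora_models::XLoraPhi3` from a `phi3::Config`.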
2086pub struct Phi3_5MoELoader;
2092
2093impl NormalModelLoader for Phi3_5MoELoader {
2094 fn load(
2095 &self,
2096 config: &str,
2097 vb: ShardedVarBuilder,
2098 normal_loading_metadata: NormalLoadingMetadata,
2099 attention_mechanism: AttentionImplementation,
2100 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2101 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2102
2103 Ok(Box::new(models::phi3_5_moe::Model::new(
2104 &cfg,
2105 vb,
2106 self.is_gptx(config)?,
2107 normal_loading_metadata,
2108 attention_mechanism,
2109 )?))
2110 }
2111 fn load_xlora(
2112 &self,
2113 config: &str,
2114 vb: ShardedVarBuilder,
2115 lora_config: &[((String, String), LoraConfig)],
2116 xlora_config: Option<XLoraConfig>,
2117 xlora_ordering: Ordering,
2118 normal_loading_metadata: NormalLoadingMetadata,
2119 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2120 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2121 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2122
2123 Ok(Box::new(xlora_models::XLoraPhi3::new(
2124 &cfg,
2125 vb,
2126 lora_config,
2127 xlora_config,
2128 xlora_ordering,
2129 self.is_gptx(config)?,
2130 normal_loading_metadata,
2131 preload_adapters,
2132 )?))
2133 }
2134 fn is_gptx(&self, _: &str) -> Result<bool> {
2135 Ok(true)
2136 }
2137 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2138 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2139
2140 Ok(Box::new(cfg))
2141 }
2142}
2143
2144impl IsqModelLoader for Phi3_5MoELoader {
2145 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2146 Ok(vec![
2147 Regex::new(r"lm_head\.(weight|bias)$")?,
2148 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2150 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2151 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2152 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2153 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2155 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2156 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2157 ])
2158 }
2159 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2160 self.isq_layer_regexes(config)
2161 }
2162
2163 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2164 Ok(vec![
2165 Regex::new(r"lm_head\.(weight|bias)$")?,
2166 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2168 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2169 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2170 ])
2171 }
2172 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2173 self.isq_layer_regexes_moqe(config)
2174 }
2175}
2176
2177impl DeviceMappedModelLoader for Phi3_5MoELoader {
2178 fn mapped_max_act_size_elems(
2179 &self,
2180 config: &str,
2181 params: &AutoDeviceMapParams,
2182 prompt_chunksize: usize,
2183 ) -> Result<usize> {
2184 let AutoDeviceMapParams::Text {
2185 max_seq_len: _,
2186 max_batch_size,
2187 } = params
2188 else {
2189 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2190 };
2191
2192 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2193
2194 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2195 }
2196 fn non_mapped_max_act_size_elems(
2197 &self,
2198 _config: &str,
2199 _params: &AutoDeviceMapParams,
2200 ) -> Result<usize> {
2201 Ok(0)
2202 }
2203
2204 fn non_mapped_size_in_bytes(
2205 &self,
2206 config: &str,
2207 dtype: DType,
2208 weight_pack_factor: usize,
2209 ) -> Result<usize> {
2210 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2211
2212 let elems = {
2213 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2214 let lm_head = if !cfg.tie_word_embeddings {
2215 cfg.hidden_size * cfg.vocab_size
2216 } else {
2217 0
2218 };
2219 let norm = cfg.hidden_size;
2220 embed_tokens + lm_head + norm
2221 };
2222 Ok(elems * dtype.size_in_bytes())
2223 }
2224
2225 fn layer_sizes_in_bytes(
2226 &self,
2227 config: &str,
2228 dtype: DType,
2229 weight_pack_factor: usize,
2230 ) -> Result<Vec<usize>> {
2231 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2232
2233 let per_layer_elems = {
2234 let input_layernorm = cfg.hidden_size;
2235 let post_attention_layernorm = cfg.hidden_size;
2236
2237 let size_in = cfg.hidden_size;
2238 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2239 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2240 let q_proj =
2241 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2242 let k_proj =
2243 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2244 let v_proj =
2245 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2246 let o_proj =
2247 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2248
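// Sparse MoE block: a router (`gate`) of hidden_size * num_local_experts elements plus
// num_local_experts copies each of the w1/w2/w3 expert projections.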
2249 let moe_block = {
2250 let gate = cfg.hidden_size * cfg.num_local_experts;
2251 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2253 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2254 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2255 gate + cfg.num_local_experts * w1
2256 + cfg.num_local_experts * w2
2257 + cfg.num_local_experts * w3
2258 };
2259
2260 input_layernorm
2261 + post_attention_layernorm
2262 + q_proj
2263 + k_proj
2264 + v_proj
2265 + o_proj
2266 + moe_block
2267 };
2268 Ok(vec![
2269 per_layer_elems * dtype.size_in_bytes();
2270 cfg.num_hidden_layers
2271 ])
2272 }
2273
2274 fn num_layers(&self, config: &str) -> Result<usize> {
2275 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2276
2277 Ok(cfg.num_hidden_layers)
2278 }
2279
2280 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2281 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2282
2283 let cfg = ModelConfigMetadata {
2284 max_seq_len: cfg.max_position_embeddings,
2285 num_layers: cfg.num_hidden_layers,
2286 hidden_size: cfg.hidden_size,
2287 num_kv_heads: cfg.num_key_value_heads,
2288 num_attn_heads: cfg.num_attention_heads,
2289 sliding_window: cfg.sliding_window,
2290 k_head_dim: cfg.head_dim(),
2291 v_head_dim: cfg.head_dim(),
2292 };
2293
2294 Ok(Box::new(cfg))
2295 }
2296}
2297
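/// [`NormalModelLoader`] for DeepSeek-V2 (`models::deepseek2::DeepSeekV2`); X-LoRA loading is
/// not implemented for this architecture.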
2298pub struct DeepSeekV2Loader;
2302
2303impl NormalModelLoader for DeepSeekV2Loader {
2304 fn load(
2305 &self,
2306 config: &str,
2307 vb: ShardedVarBuilder,
2308 normal_loading_metadata: NormalLoadingMetadata,
2309 attention_mechanism: AttentionImplementation,
2310 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2311 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2312
2313 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2314 &cfg,
2315 vb,
2316 self.is_gptx(config)?,
2317 normal_loading_metadata,
2318 attention_mechanism,
2319 )?))
2320 }
2321 fn load_xlora(
2322 &self,
2323 _config: &str,
2324 _vb: ShardedVarBuilder,
2325 _lora_config: &[((String, String), LoraConfig)],
2326 _xlora_config: Option<XLoraConfig>,
2327 _xlora_ordering: Ordering,
2328 _normal_loading_metadata: NormalLoadingMetadata,
2329 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2330 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2331 todo!()
2332 }
2333 fn is_gptx(&self, _: &str) -> Result<bool> {
2334 Ok(true)
2335 }
2336 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2337 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2338 Ok(Box::new(cfg))
2339 }
2340}
2341
2342impl IsqModelLoader for DeepSeekV2Loader {
2343 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2344 let mut data = vec![
2345 Regex::new(r"lm_head\.(weight|bias)$")?,
2346 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2348 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2349 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2350 ];
2351 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2352 if cfg.q_lora_rank.is_some() {
2353 data.extend(vec![
2354 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2355 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2356 ]);
2357 } else {
2358 data.push(Regex::new(
2359 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2360 )?);
2361 }
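// Per-layer MLP regexes: MoE layers get one pattern per routed expert (plus the shared
// experts when present), while dense layers get the plain gate/up/down projections.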
2362 for layer_idx in 0..cfg.num_hidden_layers {
2363 if cfg.n_routed_experts.is_some()
2364 && layer_idx >= cfg.first_k_dense_replace
2365 && layer_idx % cfg.moe_layer_freq == 0
2366 {
2367 for i in 0..cfg.n_routed_experts.unwrap() {
2368 data.extend(vec![
2369 Regex::new(&format!(
2370 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2371 ))?,
2372 Regex::new(&format!(
2373 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2374 ))?,
2375 Regex::new(&format!(
2376 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2377 ))?,
2378 ]);
2379 }
2380 if cfg.n_shared_experts.is_some() {
2381 data.extend(vec![
2382 Regex::new(&format!(
2383 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2384 ))?,
2385 Regex::new(&format!(
2386 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2387 ))?,
2388 Regex::new(&format!(
2389 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2390 ))?,
2391 ]);
2392 }
2393 } else {
2394 data.extend(vec![
2395 Regex::new(&format!(
2396 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2397 ))?,
2398 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2399 Regex::new(&format!(
2400 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2401 ))?,
2402 ]);
2403 };
2404 }
2405 Ok(data)
2406 }
2407 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2408 self.isq_layer_regexes(config)
2409 }
2410
2411 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2412 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2413 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2414 for layer_idx in 0..cfg.num_hidden_layers {
2415 if cfg.n_routed_experts.is_some()
2416 && layer_idx >= cfg.first_k_dense_replace
2417 && layer_idx % cfg.moe_layer_freq == 0
2418 {
2419 for i in 0..cfg.n_routed_experts.unwrap() {
2420 data.extend(vec![
2421 Regex::new(&format!(
2422 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2423 ))?,
2424 Regex::new(&format!(
2425 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2426 ))?,
2427 Regex::new(&format!(
2428 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2429 ))?,
2430 ]);
2431 }
2432 if cfg.n_shared_experts.is_some() {
2433 data.extend(vec![
2434 Regex::new(&format!(
2435 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2436 ))?,
2437 Regex::new(&format!(
2438 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2439 ))?,
2440 Regex::new(&format!(
2441 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2442 ))?,
2443 ]);
2444 }
2445 } else {
2446 data.extend(vec![
2447 Regex::new(&format!(
2448 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2449 ))?,
2450 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2451 Regex::new(&format!(
2452 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2453 ))?,
2454 ]);
2455 };
2456 }
2457 Ok(data)
2458 }
2459 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2460 self.isq_layer_regexes_moqe(config)
2461 }
2462}
2463
2464impl DeviceMappedModelLoader for DeepSeekV2Loader {
2465 fn mapped_max_act_size_elems(
2466 &self,
2467 config: &str,
2468 params: &AutoDeviceMapParams,
2469 prompt_chunksize: usize,
2470 ) -> Result<usize> {
2471 let AutoDeviceMapParams::Text {
2472 max_seq_len: _,
2473 max_batch_size,
2474 } = params
2475 else {
2476 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2477 };
2478
2479 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2480
2481 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2482 }
2483 fn non_mapped_max_act_size_elems(
2484 &self,
2485 _config: &str,
2486 _params: &AutoDeviceMapParams,
2487 ) -> Result<usize> {
2488 Ok(0)
2489 }
2490
2491 fn non_mapped_size_in_bytes(
2492 &self,
2493 config: &str,
2494 dtype: DType,
2495 weight_pack_factor: usize,
2496 ) -> Result<usize> {
2497 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2498 let elems = {
2499 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2500 let lm_head = if !cfg.tie_word_embeddings {
2501 cfg.hidden_size * cfg.vocab_size
2502 } else {
2503 0
2504 };
2505 let norm = cfg.hidden_size;
2506 embed_tokens + lm_head + norm
2507 };
2508 Ok(elems * dtype.size_in_bytes())
2509 }
2510
2511 fn layer_sizes_in_bytes(
2512 &self,
2513 config: &str,
2514 dtype: DType,
2515 weight_pack_factor: usize,
2516 ) -> Result<Vec<usize>> {
2517 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2518 let mut per_layer_elems = Vec::new();
2519
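// Layer sizes vary: a layer is MoE only when routed experts are configured,
// layer_idx >= first_k_dense_replace, and layer_idx is a multiple of moe_layer_freq;
// other layers use a single dense gated MLP. The q projection is either low-rank
// (q_a / norm / q_b when q_lora_rank is set) or a plain q_proj.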
2520 for layer_idx in 0..cfg.num_hidden_layers {
2521 let input_layernorm = cfg.hidden_size;
2522 let post_attention_layernorm = cfg.hidden_size;
2523
2524 let q_proj = match cfg.q_lora_rank {
2525 Some(lora_rank) => {
2526 let a = cfg.hidden_size * lora_rank;
2527 let norm = lora_rank;
2528 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2529 a + norm + b
2530 }
2531 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2532 };
2533 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2534 / weight_pack_factor
2535 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2536 let kv_a_layernorm = cfg.kv_lora_rank;
2537 let kv_b_proj = cfg.kv_lora_rank
2538 * cfg.num_attention_heads
2539 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2540 / weight_pack_factor;
2541 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2542 / weight_pack_factor
2543 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2544
2545 let moe_block = {
2546 let mut sum = 0;
2547 if cfg.n_routed_experts.is_some()
2548 && layer_idx >= cfg.first_k_dense_replace
2549 && layer_idx % cfg.moe_layer_freq == 0
2550 {
2551 let h_size = cfg.hidden_size;
2552 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2553 * cfg.n_routed_experts.unwrap();
2554 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2555 * cfg.n_routed_experts.unwrap();
2556 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2557 * cfg.n_routed_experts.unwrap();
2558 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2559 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2560 / weight_pack_factor;
2561 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2562 / weight_pack_factor;
2563 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2564 / weight_pack_factor;
2565 gate_proj + up_proj + down_proj
2566 } else {
2567 0
2568 };
2569 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2570 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2571 } else {
2572 let h_size = cfg.hidden_size;
2573 let i_size = cfg.intermediate_size;
2574 let gate_proj = h_size * i_size / weight_pack_factor;
2575 let up_proj = h_size * i_size / weight_pack_factor;
2576 let down_proj = i_size * h_size / weight_pack_factor;
2577 sum += gate_proj + up_proj + down_proj;
2578 }
2579 sum
2580 };
2581
2582 per_layer_elems.push(
2583 input_layernorm
2584 + post_attention_layernorm
2585 + q_proj
2586 + kv_a_layernorm
2587 + kv_a_proj_with_mqa
2588 + kv_b_proj
2589 + o_proj
2590 + moe_block,
2591 );
2592 }
2593
2594 Ok(per_layer_elems
2595 .into_iter()
2596 .map(|x| x * dtype.size_in_bytes())
2597 .collect())
2598 }
2599
2600 fn num_layers(&self, config: &str) -> Result<usize> {
2601 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2602 Ok(cfg.num_hidden_layers)
2603 }
2604
2605 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2606 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2607
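// Interpretation: with multi-head latent attention the K/V are re-expanded per attention
// head, so num_kv_heads is reported as num_attention_heads and the key head dim is the
// concatenation of the rope and nope parts.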
2608 let cfg = ModelConfigMetadata {
2609 max_seq_len: cfg.max_position_embeddings,
2610 num_layers: cfg.num_hidden_layers,
2611 hidden_size: cfg.hidden_size,
2612 num_kv_heads: cfg.num_attention_heads,
2613 num_attn_heads: cfg.num_attention_heads,
2614 sliding_window: None,
2615 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2616 v_head_dim: cfg.v_head_dim,
2617 };
2618
2619 Ok(Box::new(cfg))
2620 }
2621}
2622
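/// [`NormalModelLoader`] for DeepSeek-V3 (`models::deepseek3::DeepSeekV3`); like DeepSeek-V2,
/// X-LoRA loading is not implemented.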
2623pub struct DeepSeekV3Loader;
2627
2628impl NormalModelLoader for DeepSeekV3Loader {
2629 fn load(
2630 &self,
2631 config: &str,
2632 vb: ShardedVarBuilder,
2633 normal_loading_metadata: NormalLoadingMetadata,
2634 attention_mechanism: AttentionImplementation,
2635 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2636 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2637 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2638 &cfg,
2639 vb,
2640 self.is_gptx(config)?,
2641 normal_loading_metadata,
2642 attention_mechanism,
2643 )?))
2644 }
2645 fn load_xlora(
2646 &self,
2647 _config: &str,
2648 _vb: ShardedVarBuilder,
2649 _lora_config: &[((String, String), LoraConfig)],
2650 _xlora_config: Option<XLoraConfig>,
2651 _xlora_ordering: Ordering,
2652 _normal_loading_metadata: NormalLoadingMetadata,
2653 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2654 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2655 todo!()
2656 }
2657 fn is_gptx(&self, _: &str) -> Result<bool> {
2658 Ok(true)
2659 }
2660 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2661 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2662 Ok(Box::new(cfg))
2663 }
2664}
2665
2666impl IsqModelLoader for DeepSeekV3Loader {
2667 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2668 let mut data = vec![
2669 Regex::new(r"lm_head\.(weight|bias)$")?,
2670 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2672 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2673 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2674 ];
2675 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2676 if cfg.q_lora_rank.is_some() {
2677 data.extend(vec![
2678 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2679 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2680 ]);
2681 } else {
2682 data.push(Regex::new(
2683 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2684 )?);
2685 }
2686 for layer_idx in 0..cfg.num_hidden_layers {
2687 if cfg.n_routed_experts.is_some()
2688 && layer_idx >= cfg.first_k_dense_replace
2689 && layer_idx % cfg.moe_layer_freq == 0
2690 {
2691 for i in 0..cfg.n_routed_experts.unwrap() {
2692 data.extend(vec![
2693 Regex::new(&format!(
2694 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2695 ))?,
2696 Regex::new(&format!(
2697 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2698 ))?,
2699 Regex::new(&format!(
2700 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2701 ))?,
2702 ]);
2703 }
2704 if cfg.n_shared_experts.is_some() {
2705 data.extend(vec![
2706 Regex::new(&format!(
2707 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2708 ))?,
2709 Regex::new(&format!(
2710 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2711 ))?,
2712 Regex::new(&format!(
2713 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2714 ))?,
2715 ]);
2716 }
2717 } else {
2718 data.extend(vec![
2719 Regex::new(&format!(
2720 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2721 ))?,
2722 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2723 Regex::new(&format!(
2724 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2725 ))?,
2726 ]);
2727 };
2728 }
2729 Ok(data)
2730 }
2731 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2732 self.isq_layer_regexes(config)
2733 }
2734
2735 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2736 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2737 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2738 for layer_idx in 0..cfg.num_hidden_layers {
2739 if cfg.n_routed_experts.is_some()
2740 && layer_idx >= cfg.first_k_dense_replace
2741 && layer_idx % cfg.moe_layer_freq == 0
2742 {
2743 for i in 0..cfg.n_routed_experts.unwrap() {
2744 data.extend(vec![
2745 Regex::new(&format!(
2746 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2747 ))?,
2748 Regex::new(&format!(
2749 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2750 ))?,
2751 Regex::new(&format!(
2752 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2753 ))?,
2754 ]);
2755 }
2756 if cfg.n_shared_experts.is_some() {
2757 data.extend(vec![
2758 Regex::new(&format!(
2759 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2760 ))?,
2761 Regex::new(&format!(
2762 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2763 ))?,
2764 Regex::new(&format!(
2765 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2766 ))?,
2767 ]);
2768 }
2769 } else {
2770 data.extend(vec![
2771 Regex::new(&format!(
2772 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2773 ))?,
2774 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2775 Regex::new(&format!(
2776 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2777 ))?,
2778 ]);
2779 };
2780 }
2781 Ok(data)
2782 }
2783 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2784 self.isq_layer_regexes_moqe(config)
2785 }
2786}
2787
2788impl DeviceMappedModelLoader for DeepSeekV3Loader {
2789 fn mapped_max_act_size_elems(
2790 &self,
2791 config: &str,
2792 params: &AutoDeviceMapParams,
2793 prompt_chunksize: usize,
2794 ) -> Result<usize> {
2795 let AutoDeviceMapParams::Text {
2796 max_seq_len: _,
2797 max_batch_size,
2798 } = params
2799 else {
2800 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2801 };
2802
2803 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2804
2805 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2806 }
2807 fn non_mapped_max_act_size_elems(
2808 &self,
2809 _config: &str,
2810 _params: &AutoDeviceMapParams,
2811 ) -> Result<usize> {
2812 Ok(0)
2813 }
2814
2815 fn non_mapped_size_in_bytes(
2816 &self,
2817 config: &str,
2818 dtype: DType,
2819 weight_pack_factor: usize,
2820 ) -> Result<usize> {
2821 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2822 let elems = {
2823 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2824 let lm_head = if !cfg.tie_word_embeddings {
2825 cfg.hidden_size * cfg.vocab_size
2826 } else {
2827 0
2828 };
2829 let norm = cfg.hidden_size;
2830 embed_tokens + lm_head + norm
2831 };
2832 Ok(elems * dtype.size_in_bytes())
2833 }
2834
2835 fn layer_sizes_in_bytes(
2836 &self,
2837 config: &str,
2838 dtype: DType,
2839 weight_pack_factor: usize,
2840 ) -> Result<Vec<usize>> {
2841 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2842 let mut per_layer_elems = Vec::new();
2843
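// Same per-layer accounting as the DeepSeek-V2 loader above: MLA projections plus either a
// dense gated MLP or a routed-expert MoE block, depending on the layer index.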
2844 for layer_idx in 0..cfg.num_hidden_layers {
2845 let input_layernorm = cfg.hidden_size;
2846 let post_attention_layernorm = cfg.hidden_size;
2847
2848 let q_proj = match cfg.q_lora_rank {
2849 Some(lora_rank) => {
2850 let a = cfg.hidden_size * lora_rank;
2851 let norm = lora_rank;
2852 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2853 a + norm + b
2854 }
2855 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2856 };
2857 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2858 / weight_pack_factor
2859 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2860 let kv_a_layernorm = cfg.kv_lora_rank;
2861 let kv_b_proj = cfg.kv_lora_rank
2862 * cfg.num_attention_heads
2863 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2864 / weight_pack_factor;
2865 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2866 / weight_pack_factor
2867 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2868
2869 let moe_block = {
2870 let mut sum = 0;
2871 if cfg.n_routed_experts.is_some()
2872 && layer_idx >= cfg.first_k_dense_replace
2873 && layer_idx % cfg.moe_layer_freq == 0
2874 {
2875 let h_size = cfg.hidden_size;
2876 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2877 * cfg.n_routed_experts.unwrap();
2878 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2879 * cfg.n_routed_experts.unwrap();
2880 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2881 * cfg.n_routed_experts.unwrap();
2882 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2883 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2884 / weight_pack_factor;
2885 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2886 / weight_pack_factor;
2887 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2888 / weight_pack_factor;
2889 gate_proj + up_proj + down_proj
2890 } else {
2891 0
2892 };
2893 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2894 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2895 } else {
2896 let h_size = cfg.hidden_size;
2897 let i_size = cfg.intermediate_size;
2898 let gate_proj = h_size * i_size / weight_pack_factor;
2899 let up_proj = h_size * i_size / weight_pack_factor;
2900 let down_proj = i_size * h_size / weight_pack_factor;
2901 sum += gate_proj + up_proj + down_proj;
2902 }
2903 sum
2904 };
2905
2906 per_layer_elems.push(
2907 input_layernorm
2908 + post_attention_layernorm
2909 + q_proj
2910 + kv_a_layernorm
2911 + kv_a_proj_with_mqa
2912 + kv_b_proj
2913 + o_proj
2914 + moe_block,
2915 );
2916 }
2917
2918 Ok(per_layer_elems
2919 .into_iter()
2920 .map(|x| x * dtype.size_in_bytes())
2921 .collect())
2922 }
2923
2924 fn num_layers(&self, config: &str) -> Result<usize> {
2925 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2926 Ok(cfg.num_hidden_layers)
2927 }
2928
2929 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2930 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2931
2932 let cfg = ModelConfigMetadata {
2933 max_seq_len: cfg.max_position_embeddings,
2934 num_layers: cfg.num_hidden_layers,
2935 hidden_size: cfg.hidden_size,
2936 num_kv_heads: cfg.num_attention_heads,
2937 num_attn_heads: cfg.num_attention_heads,
2938 sliding_window: None,
2939 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2940 v_head_dim: cfg.v_head_dim,
2941 };
2942
2943 Ok(Box::new(cfg))
2944 }
2945}
2946
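/// [`NormalModelLoader`] for dense Qwen3 models (`models::qwen3::Model`); X-LoRA loading is
/// not implemented.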
2947pub struct Qwen3Loader;
2951
2952impl NormalModelLoader for Qwen3Loader {
2953 fn load(
2954 &self,
2955 config: &str,
2956 vb: ShardedVarBuilder,
2957 normal_loading_metadata: NormalLoadingMetadata,
2958 attention_mechanism: AttentionImplementation,
2959 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2960 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
2961
2962 Ok(Box::new(models::qwen3::Model::new(
2963 &cfg,
2964 vb,
2965 self.is_gptx(config)?,
2966 normal_loading_metadata,
2967 attention_mechanism,
2968 )?))
2969 }
2970 fn load_xlora(
2971 &self,
2972 _config: &str,
2973 _vb: ShardedVarBuilder,
2974 _lora_config: &[((String, String), LoraConfig)],
2975 _xlora_config: Option<XLoraConfig>,
2976 _xlora_ordering: Ordering,
2977 _normal_loading_metadata: NormalLoadingMetadata,
2978 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2979 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2980 todo!()
2981 }
2982 fn is_gptx(&self, _: &str) -> Result<bool> {
2983 Ok(true)
2984 }
2985 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2986 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
2987
2988 Ok(Box::new(cfg))
2989 }
2990}
2991
2992impl IsqModelLoader for Qwen3Loader {
2993 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2994 Ok(vec![
2995 Regex::new(r"lm_head\.(weight|bias)$")?,
2996 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2998 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2999 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3000 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3001 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3003 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3004 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3005 ])
3006 }
3007 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3008 self.isq_layer_regexes(config)
3009 }
3010}
3011
3012impl DeviceMappedModelLoader for Qwen3Loader {
3013 fn mapped_max_act_size_elems(
3014 &self,
3015 config: &str,
3016 params: &AutoDeviceMapParams,
3017 prompt_chunksize: usize,
3018 ) -> Result<usize> {
3019 let AutoDeviceMapParams::Text {
3020 max_seq_len: _,
3021 max_batch_size,
3022 } = params
3023 else {
3024 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3025 };
3026
3027 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3028
3029 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3030 }
3031 fn non_mapped_max_act_size_elems(
3032 &self,
3033 _config: &str,
3034 _params: &AutoDeviceMapParams,
3035 ) -> Result<usize> {
3036 Ok(0)
3037 }
3038
3039 fn non_mapped_size_in_bytes(
3040 &self,
3041 config: &str,
3042 dtype: DType,
3043 weight_pack_factor: usize,
3044 ) -> Result<usize> {
3045 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3046 let elems = {
3047 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3048 let lm_head = if !cfg.tie_word_embeddings {
3049 cfg.hidden_size * cfg.vocab_size
3050 } else {
3051 0
3052 };
3053 let norm = cfg.hidden_size;
3054 embed_tokens + lm_head + norm
3055 };
3056 Ok(elems * dtype.size_in_bytes())
3057 }
3058
3059 fn layer_sizes_in_bytes(
3060 &self,
3061 config: &str,
3062 dtype: DType,
3063 weight_pack_factor: usize,
3064 ) -> Result<Vec<usize>> {
3065 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3066 let per_layer_elems = {
3067 let input_layernorm = cfg.hidden_size;
3068 let post_attention_layernorm = cfg.hidden_size;
3069
3070 let size_in = cfg.hidden_size;
3071 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3072 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3073 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3074 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3075 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3076 let o_proj = size_q * size_in / weight_pack_factor;
3077
3078 let h_size = cfg.hidden_size;
3079 let i_size = cfg.intermediate_size;
3080 let gate_proj = h_size * i_size / weight_pack_factor;
3081 let up_proj = h_size * i_size / weight_pack_factor;
3082 let down_proj = i_size * h_size / weight_pack_factor;
3083
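// Qwen3 adds per-head q/k RMSNorm weights of length head_dim.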
3084 let q_norm = cfg.head_dim();
3085 let k_norm = cfg.head_dim();
3086
3087 input_layernorm
3088 + post_attention_layernorm
3089 + q_proj
3090 + k_proj
3091 + v_proj
3092 + o_proj
3093 + gate_proj
3094 + up_proj
3095 + down_proj
3096 + q_norm
3097 + k_norm
3098 };
3099 Ok(vec![
3100 per_layer_elems * dtype.size_in_bytes();
3101 cfg.num_hidden_layers
3102 ])
3103 }
3104
3105 fn num_layers(&self, config: &str) -> Result<usize> {
3106 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3107 Ok(cfg.num_hidden_layers)
3108 }
3109
3110 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3111 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3112
3113 let cfg = ModelConfigMetadata {
3114 max_seq_len: cfg.max_position_embeddings,
3115 num_layers: cfg.num_hidden_layers,
3116 hidden_size: cfg.hidden_size,
3117 num_kv_heads: cfg.num_key_value_heads,
3118 num_attn_heads: cfg.num_attention_heads,
3119 sliding_window: cfg.sliding_window,
3120 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3121 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3122 };
3123
3124 Ok(Box::new(cfg))
3125 }
3126}
3127
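/// [`NormalModelLoader`] for Qwen3-MoE (`models::qwen3_moe::Model`); X-LoRA loading is
/// not implemented.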
3128pub struct Qwen3MoELoader;
3132
3133impl NormalModelLoader for Qwen3MoELoader {
3134 fn load(
3135 &self,
3136 config: &str,
3137 vb: ShardedVarBuilder,
3138 normal_loading_metadata: NormalLoadingMetadata,
3139 attention_mechanism: AttentionImplementation,
3140 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3141 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3142
3143 Ok(Box::new(models::qwen3_moe::Model::new(
3144 &cfg,
3145 vb,
3146 self.is_gptx(config)?,
3147 normal_loading_metadata,
3148 attention_mechanism,
3149 )?))
3150 }
3151 fn load_xlora(
3152 &self,
3153 _config: &str,
3154 _vb: ShardedVarBuilder,
3155 _lora_config: &[((String, String), LoraConfig)],
3156 _xlora_config: Option<XLoraConfig>,
3157 _xlora_ordering: Ordering,
3158 _normal_loading_metadata: NormalLoadingMetadata,
3159 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3160 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3161 todo!()
3162 }
3163 fn is_gptx(&self, _: &str) -> Result<bool> {
3164 Ok(true)
3165 }
3166 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3167 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3168
3169 Ok(Box::new(cfg))
3170 }
3171}
3172
3173impl IsqModelLoader for Qwen3MoELoader {
3174 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3175 Ok(vec![
3176 Regex::new(r"lm_head\.(weight|bias)$")?,
3177 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3179 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3180 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3181 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3182 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3184 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3185 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3186 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3188 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3189 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3190 ])
3191 }
3192 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3193 self.isq_layer_regexes(config)
3194 }
3195 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3196 self.isq_layer_regexes_moqe(config)
3197 }
3198}
3199
3200impl DeviceMappedModelLoader for Qwen3MoELoader {
3201 fn mapped_max_act_size_elems(
3202 &self,
3203 config: &str,
3204 params: &AutoDeviceMapParams,
3205 prompt_chunksize: usize,
3206 ) -> Result<usize> {
3207 let AutoDeviceMapParams::Text {
3208 max_seq_len: _,
3209 max_batch_size,
3210 } = params
3211 else {
3212 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3213 };
3214
3215 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3216
3217 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3218 }
3219 fn non_mapped_max_act_size_elems(
3220 &self,
3221 _config: &str,
3222 _params: &AutoDeviceMapParams,
3223 ) -> Result<usize> {
3224 Ok(0)
3225 }
3226
3227 fn non_mapped_size_in_bytes(
3228 &self,
3229 config: &str,
3230 dtype: DType,
3231 weight_pack_factor: usize,
3232 ) -> Result<usize> {
3233 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3234 let elems = {
3235 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3236 let lm_head = if !cfg.tie_word_embeddings {
3237 cfg.hidden_size * cfg.vocab_size
3238 } else {
3239 0
3240 };
3241 let norm = cfg.hidden_size;
3242 embed_tokens + lm_head + norm
3243 };
3244 Ok(elems * dtype.size_in_bytes())
3245 }
3246
3247 fn layer_sizes_in_bytes(
3248 &self,
3249 config: &str,
3250 dtype: DType,
3251 weight_pack_factor: usize,
3252 ) -> Result<Vec<usize>> {
3253 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3254
3255 let mut layer_sizes_in_bytes = Vec::new();
3256 for layer_idx in 0..cfg.num_hidden_layers {
3257 let input_layernorm = cfg.hidden_size;
3258 let post_attention_layernorm = cfg.hidden_size;
3259
3260 let size_in = cfg.hidden_size;
3261 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3262 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3263 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3264 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3265 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3266 let o_proj = size_q * size_in / weight_pack_factor;
3267
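// A layer uses the sparse MoE MLP only when it is not listed in mlp_only_layers,
// num_experts > 0, and (layer_idx + 1) is a multiple of decoder_sparse_step; otherwise it
// falls back to the dense gated MLP sized by intermediate_size.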
3268 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3269 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3270 {
3271 let expert_size = {
3272 let h_size = cfg.hidden_size;
3273 let i_size = cfg.moe_intermediate_size;
3274 let gate_proj = h_size * i_size / weight_pack_factor;
3275 let up_proj = h_size * i_size / weight_pack_factor;
3276 let down_proj = i_size * h_size / weight_pack_factor;
3277 gate_proj + up_proj + down_proj
3278 };
3279 expert_size * cfg.num_experts
3280 } else {
3281 let h_size = cfg.hidden_size;
3282 let i_size = cfg.intermediate_size;
3283 let gate_proj = h_size * i_size / weight_pack_factor;
3284 let up_proj = h_size * i_size / weight_pack_factor;
3285 let down_proj = i_size * h_size / weight_pack_factor;
3286 gate_proj + up_proj + down_proj
3287 };
3288
3289 let q_norm = cfg.head_dim();
3290 let k_norm = cfg.head_dim();
3291
3292 let size_elems = input_layernorm
3293 + post_attention_layernorm
3294 + q_proj
3295 + k_proj
3296 + v_proj
3297 + o_proj
3298 + mlp_size
3299 + q_norm
3300 + k_norm;
3301
3302 let size_in_bytes = size_elems * dtype.size_in_bytes();
3303 layer_sizes_in_bytes.push(size_in_bytes);
3304 }
3305
3306 Ok(layer_sizes_in_bytes)
3307 }
3308
3309 fn num_layers(&self, config: &str) -> Result<usize> {
3310 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3311 Ok(cfg.num_hidden_layers)
3312 }
3313
3314 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3315 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3316
3317 let cfg = ModelConfigMetadata {
3318 max_seq_len: cfg.max_position_embeddings,
3319 num_layers: cfg.num_hidden_layers,
3320 hidden_size: cfg.hidden_size,
3321 num_kv_heads: cfg.num_key_value_heads,
3322 num_attn_heads: cfg.num_attention_heads,
3323 sliding_window: cfg.sliding_window,
3324 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3325 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3326 };
3327
3328 Ok(Box::new(cfg))
3329 }
3330}