use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::{Activation, Llama3RopeConfig, PhiRopeScalingConfig},
    lora::{LoraConfig, Ordering},
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{
        isq::IsqModelLoader,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    serde_default_fn,
    utils::{log::once_log_info, varbuilder_utils::DeviceForLoadTensor},
    xlora_models::NonGranularState,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};

use indicatif::MultiProgress;
use mistralrs_quant::{QuantizedConfig, ShardedVarBuilder};
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use crate::{
    models,
    xlora_models::{self, XLoraConfig},
};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn is_xlora(&self) -> bool;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
}

/// Metadata required when loading a [`NormalModel`].
pub struct NormalLoadingMetadata {
    // Device mapper used to place each layer's weights
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Flag indicating whether the model is being loaded for in-situ quantization (ISQ)
    pub loading_isq: bool,
    // The device the model will ultimately run on
    pub real_device: Device,
    // Shared progress-bar handle used during parallelized loading
    pub multi_progress: Arc<MultiProgress>,
}

pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    #[allow(clippy::too_many_arguments)]
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
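            // Illustrative behavior of the closure built below (not exercised here):
            // a tensor named "model.layers.11.self_attn.q_proj.weight" yields
            // `DeviceForLoadTensor::Idx(11)`, while names without a ".layers.N."
            // component (e.g. "lm_head.weight") fall back to `DeviceForLoadTensor::Base`.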
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
}

impl NormalLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "MistralForCausalLM" => Ok(Self::Mistral),
            "MixtralForCausalLM" => Ok(Self::Mixtral),
            "GemmaForCausalLM" => Ok(Self::Gemma),
            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
            "PhiForCausalLM" => Ok(Self::Phi2),
            "Phi3ForCausalLM" => Ok(Self::Phi3),
            "LlamaForCausalLM" => Ok(Self::Llama),
            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `*ForCausalLM` model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for NormalLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "mistral" => Ok(Self::Mistral),
            "gemma" => Ok(Self::Gemma),
            "mixtral" => Ok(Self::Mixtral),
            "llama" => Ok(Self::Llama),
            "phi2" => Ok(Self::Phi2),
            "phi3" => Ok(Self::Phi3),
            "qwen2" => Ok(Self::Qwen2),
            "gemma2" => Ok(Self::Gemma2),
            "starcoder2" => Ok(Self::Starcoder2),
            "phi3.5moe" => Ok(Self::Phi3_5MoE),
            "deepseekv2" => Ok(Self::DeepSeekV2),
            "deepseekv3" => Ok(Self::DeepSeekV3),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`.")),
        }
    }
}
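
// Illustrative round trip through `FromStr` and the `Display` impl below (not exercised here):
//     let tp: NormalLoaderType = "phi3.5moe".parse().unwrap();
//     assert_eq!(tp, NormalLoaderType::Phi3_5MoE);
//     assert_eq!(tp.to_string(), "phi3.5moe");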

impl Display for NormalLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gemma => write!(f, "gemma"),
            Self::Gemma2 => write!(f, "gemma2"),
            Self::Llama => write!(f, "llama"),
            Self::Mistral => write!(f, "mistral"),
            Self::Mixtral => write!(f, "mixtral"),
            Self::Phi2 => write!(f, "phi2"),
            Self::Phi3 => write!(f, "phi3"),
            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
            Self::Qwen2 => write!(f, "qwen2"),
            Self::Starcoder2 => write!(f, "starcoder2"),
            Self::DeepSeekV2 => write!(f, "deepseekv2"),
            Self::DeepSeekV3 => write!(f, "deepseekv3"),
        }
    }
}

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
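
// `bias_if!(cond, size)` contributes `size` elements to a size estimate when the
// projection carries a bias and `0` otherwise, e.g. `bias_if!(cfg.attention_bias, size_q)`
// in the per-layer accounting below.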

/// Loader that inspects the `architectures` field of a config and delegates to
/// the matching model-specific loader.
pub struct AutoLoader;

#[derive(Deserialize)]
struct AutoLoaderConfig {
    architectures: Vec<String>,
}
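
// Illustrative fragment of a Hugging Face `config.json` that `AutoLoader` can dispatch on
// (only the field shown here is read by `AutoLoaderConfig`; the concrete loader will need
// the rest of the config as well):
//     { "architectures": ["MistralForCausalLM"], ... }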

impl AutoLoader {
    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
        let auto_cfg: AutoLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one name in the `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = NormalLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
        }
    }
}

impl NormalModelLoader for AutoLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load(
            config,
            use_flash_attn,
            vb,
            normal_loading_metadata,
            attention_mechanism,
        )
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load_xlora(
            config,
            use_flash_attn,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            normal_loading_metadata,
            preload_adapters,
        )
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config, use_flash_attn)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}
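
// Minimal sketch of delegating through `AutoLoader` (assuming a `config.json` string,
// a prepared `ShardedVarBuilder`, loading metadata, and that `AttentionImplementation::Eager`
// is the desired attention backend; not exercised here):
//     let model = AutoLoader.load(&config_json, false, vb, metadata, AttentionImplementation::Eager)?;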

impl IsqModelLoader for AutoLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for AutoLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

serde_default_fn!(bool, word_emb_default, false);

#[derive(Deserialize, Debug)]
struct MistralBasicConfig {
    vocab_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    hidden_act: Activation,
    max_position_embeddings: usize,
    rms_norm_eps: f64,
    rope_theta: f64,
    sliding_window: Option<usize>,
    head_dim: Option<usize>,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl MistralBasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::mistral::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::mistral::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            max_position_embeddings: basic_config.max_position_embeddings,
            rms_norm_eps: basic_config.rms_norm_eps,
            rope_theta: basic_config.rope_theta,
            sliding_window: basic_config.sliding_window,
            use_flash_attn,
            head_dim: basic_config.head_dim,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}
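
// Illustrative shape of the JSON that `MistralBasicConfig::deserialize` accepts
// (values are placeholders, not taken from any particular model):
//     {
//         "vocab_size": 32000, "hidden_size": 4096, "intermediate_size": 14336,
//         "num_hidden_layers": 32, "num_attention_heads": 32, "num_key_value_heads": 8,
//         "hidden_act": "silu", "max_position_embeddings": 32768,
//         "rms_norm_eps": 1e-5, "rope_theta": 10000.0
//     }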

/// [`NormalModelLoader`] for Mistral-style models.
pub struct MistralLoader;

impl NormalModelLoader for MistralLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::mistral::Model::new(
            &MistralBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraMistral::new(
            &MistralBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(MistralBasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for MistralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for MistralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = MistralBasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = MistralBasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = MistralBasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }
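
    // Back-of-the-envelope check (illustrative numbers, not from any specific model):
    // with hidden_size = 4096, intermediate_size = 14336, 32 attention heads, 8 KV heads
    // and no weight packing, the sum above is roughly 16.8M (q) + 4.2M (k) + 4.2M (v)
    // + 16.8M (o) + 3 * 58.7M (MLP) + 8192 (norms) ≈ 218M elements per layer,
    // i.e. about 0.44 GB per layer at 2 bytes per element.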

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = MistralBasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = MistralBasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

fn default_max_position_embeddings() -> usize {
    4096
}

#[derive(Deserialize)]
struct GemmaBasicConfig {
    attention_bias: bool,
    head_dim: usize,
    hidden_act: Option<Activation>,
    hidden_activation: Option<Activation>,
    hidden_size: usize,
    intermediate_size: usize,
    num_attention_heads: usize,
    num_hidden_layers: usize,
    num_key_value_heads: usize,
    rms_norm_eps: f64,
    rope_theta: f64,
    vocab_size: usize,

    #[serde(default = "default_max_position_embeddings")]
    max_position_embeddings: usize,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl GemmaBasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::gemma::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::gemma::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            hidden_activation: basic_config.hidden_activation,
            max_position_embeddings: basic_config.max_position_embeddings,
            rms_norm_eps: basic_config.rms_norm_eps,
            rope_theta: basic_config.rope_theta,
            attention_bias: basic_config.attention_bias,
            head_dim: basic_config.head_dim,
            use_flash_attn,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

/// [`NormalModelLoader`] for Gemma models.
pub struct GemmaLoader;

impl NormalModelLoader for GemmaLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::gemma::Model::new(
            &GemmaBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraGemma::new(
            &GemmaBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(GemmaBasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for GemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for GemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = GemmaBasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = GemmaBasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = GemmaBasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = GemmaBasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = GemmaBasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }
}

#[derive(Deserialize)]
struct LlamaBasicConfig {
    hidden_act: Activation,
    hidden_size: usize,
    intermediate_size: usize,
    vocab_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: Option<usize>,
    rms_norm_eps: f64,
    #[serde(default = "default_rope")]
    rope_theta: f32,
    max_position_embeddings: usize,
    rope_scaling: Option<Llama3RopeConfig>,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

fn default_rope() -> f32 {
    10_000.0
}

impl LlamaBasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::llama::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::llama::Config {
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            vocab_size: basic_config.vocab_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config
                .num_key_value_heads
                .unwrap_or(basic_config.num_attention_heads),
            rms_norm_eps: basic_config.rms_norm_eps,
            rope_theta: basic_config.rope_theta,
            use_flash_attn,
            max_position_embeddings: basic_config.max_position_embeddings,
            rope_scaling: basic_config.rope_scaling,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
            hidden_act: basic_config.hidden_act,
        })
    }
}

/// [`NormalModelLoader`] for Llama models.
pub struct LlamaLoader;

impl NormalModelLoader for LlamaLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::llama::Llama::new(
            &LlamaBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraLlama::new(
            &LlamaBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(LlamaBasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for LlamaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = LlamaBasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = LlamaBasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = LlamaBasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = LlamaBasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = LlamaBasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

#[derive(Deserialize)]
struct MixtralBasicConfig {
    vocab_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    hidden_act: Activation,
    max_position_embeddings: usize,
    rms_norm_eps: f64,
    rope_theta: f64,
    sliding_window: Option<usize>,
    num_experts_per_tok: usize,
    num_local_experts: usize,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl MixtralBasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::mixtral::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::mixtral::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            max_position_embeddings: basic_config.max_position_embeddings,
            rms_norm_eps: basic_config.rms_norm_eps,
            rope_theta: basic_config.rope_theta,
            sliding_window: basic_config.sliding_window,
            use_flash_attn,
            num_experts_per_tok: basic_config.num_experts_per_tok,
            num_local_experts: basic_config.num_local_experts,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

/// [`NormalModelLoader`] for Mixtral models.
pub struct MixtralLoader;

impl NormalModelLoader for MixtralLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::mixtral::Model::new(
            &MixtralBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraMixtral::new(
            &MixtralBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(MixtralBasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for MixtralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for MixtralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = MixtralBasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = MixtralBasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = MixtralBasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let moe_block = {
                let gate = cfg.hidden_size * cfg.num_local_experts;
                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                gate + cfg.num_local_experts * w1
                    + cfg.num_local_experts * w2
                    + cfg.num_local_experts * w3
            };

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + moe_block
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = MixtralBasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = MixtralBasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

#[derive(Deserialize)]
struct Phi2BasicConfig {
    vocab_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: Option<usize>,
    hidden_act: Activation,
    max_position_embeddings: usize,
    layer_norm_eps: f64,
    rope_theta: f32,
    partial_rotary_factor: f64,
    qk_layernorm: bool,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl Phi2BasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi2::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::phi2::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            max_position_embeddings: basic_config.max_position_embeddings,
            rope_theta: basic_config.rope_theta,
            layer_norm_eps: basic_config.layer_norm_eps,
            partial_rotary_factor: basic_config.partial_rotary_factor,
            qk_layernorm: basic_config.qk_layernorm,
            use_flash_attn,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

/// [`NormalModelLoader`] for Phi-2 models.
pub struct Phi2Loader;

impl NormalModelLoader for Phi2Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::phi2::Model::new(
            &Phi2BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraPhi2::new(
            &Phi2BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(Phi2BasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for Phi2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Phi2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = Phi2BasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = Phi2BasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = Phi2BasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor + size_in;
            let (q_norm, k_norm) = if cfg.qk_layernorm {
                (cfg.head_dim(), cfg.head_dim())
            } else {
                (0, 0)
            };

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let fc1 = h_size * i_size / weight_pack_factor;
            let fc2 = h_size * i_size / weight_pack_factor;

            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = Phi2BasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = Phi2BasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

#[derive(Deserialize)]
struct Phi3BasicConfig {
    vocab_size: usize,
    hidden_act: Activation,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    rms_norm_eps: f64,
    rope_theta: f64,
    bos_token_id: Option<u32>,
    eos_token_id: Option<u32>,
    rope_scaling: Option<PhiRopeScalingConfig>,
    max_position_embeddings: usize,
    original_max_position_embeddings: usize,
    sliding_window: Option<usize>,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
    partial_rotary_factor: Option<f64>,
}

impl Phi3BasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi3::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::phi3::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            max_position_embeddings: basic_config.max_position_embeddings,
            rope_theta: basic_config.rope_theta,
            rms_norm_eps: basic_config.rms_norm_eps,
            eos_token_id: basic_config.eos_token_id,
            bos_token_id: basic_config.bos_token_id,
            rope_scaling: basic_config.rope_scaling,
            original_max_position_embeddings: basic_config.original_max_position_embeddings,
            use_flash_attn,
            sliding_window: basic_config.sliding_window,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
            partial_rotary_factor: basic_config.partial_rotary_factor,
        })
    }
}

/// [`NormalModelLoader`] for Phi-3 models.
pub struct Phi3Loader;

impl NormalModelLoader for Phi3Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::phi3::Model::new(
            &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraPhi3::new(
            &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(Phi3BasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for Phi3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Phi3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = Phi3BasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = Phi3BasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = Phi3BasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            // The fused qkv_proj packs q (num_heads * head_dim) together with k and v
            // (num_kv_heads * head_dim each), so its output width is the sum of the three.
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj =
                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = Phi3BasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = Phi3BasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

#[derive(Deserialize)]
struct Qwen2BasicConfig {
    vocab_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    max_position_embeddings: usize,
    sliding_window: usize,
    rope_theta: f64,
    rms_norm_eps: f64,
    hidden_act: Activation,
    quantization_config: Option<QuantizedConfig>,
    tie_word_embeddings: bool,
}

impl Qwen2BasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::qwen2::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::qwen2::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            max_position_embeddings: basic_config.max_position_embeddings,
            rope_theta: basic_config.rope_theta,
            rms_norm_eps: basic_config.rms_norm_eps,
            sliding_window: basic_config.sliding_window,
            use_flash_attn,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

/// [`NormalModelLoader`] for Qwen2 models.
pub struct Qwen2Loader;

impl NormalModelLoader for Qwen2Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::qwen2::Model::new(
            &Qwen2BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _use_flash_attn: bool,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(Qwen2BasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for Qwen2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

1853impl DeviceMappedModelLoader for Qwen2Loader {
1854 fn mapped_max_act_size_elems(
1855 &self,
1856 config: &str,
1857 params: &AutoDeviceMapParams,
1858 prompt_chunksize: usize,
1859 ) -> Result<usize> {
1860 let AutoDeviceMapParams::Text {
1861 max_seq_len: _,
1862 max_batch_size,
1863 } = params
1864 else {
1865 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1866 };
1867
1868 let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1869
1870 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1871 }
1872 fn non_mapped_max_act_size_elems(
1873 &self,
1874 _config: &str,
1875 _params: &AutoDeviceMapParams,
1876 ) -> Result<usize> {
1877 Ok(0)
1878 }
1879
1880 fn non_mapped_size_in_bytes(
1881 &self,
1882 config: &str,
1883 dtype: DType,
1884 weight_pack_factor: usize,
1885 ) -> Result<usize> {
1886 let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1887 let elems = {
1888 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1889 let lm_head = if !cfg.tie_word_embeddings {
1890 cfg.hidden_size * cfg.vocab_size
1891 } else {
1892 0
1893 };
1894 let norm = cfg.hidden_size;
1895 embed_tokens + lm_head + norm
1896 };
1897 Ok(elems * dtype.size_in_bytes())
1898 }
1899
1900 fn layer_sizes_in_bytes(
1901 &self,
1902 config: &str,
1903 dtype: DType,
1904 weight_pack_factor: usize,
1905 ) -> Result<Vec<usize>> {
1906 let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1907 let per_layer_elems = {
1908 let input_layernorm = cfg.hidden_size;
1909 let post_attention_layernorm = cfg.hidden_size;
1910
1911 let size_in = cfg.hidden_size;
1912 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1913 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1914 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1915 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1916 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1917 let o_proj = size_q * size_in / weight_pack_factor;
1918
1919 let h_size = cfg.hidden_size;
1920 let i_size = cfg.intermediate_size;
1921 let gate_proj = h_size * i_size / weight_pack_factor;
1922 let up_proj = h_size * i_size / weight_pack_factor;
1923 let down_proj = i_size * h_size / weight_pack_factor;
1924
1925 input_layernorm
1926 + post_attention_layernorm
1927 + q_proj
1928 + k_proj
1929 + v_proj
1930 + o_proj
1931 + gate_proj
1932 + up_proj
1933 + down_proj
1934 };
1935 Ok(vec![
1936 per_layer_elems * dtype.size_in_bytes();
1937 cfg.num_hidden_layers
1938 ])
1939 }
1940
1941 fn num_layers(&self, config: &str) -> Result<usize> {
1942 let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1943 Ok(cfg.num_hidden_layers)
1944 }
1945
1946 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1947 let cfg = Qwen2BasicConfig::deserialize(config, false)?;
1948
1949 let cfg = ModelConfigMetadata {
1950 max_seq_len: cfg.max_position_embeddings,
1951 num_layers: cfg.num_hidden_layers,
1952 hidden_size: cfg.hidden_size,
1953 num_kv_heads: cfg.num_key_value_heads,
1954 num_attn_heads: cfg.num_attention_heads,
1955 sliding_window: Some(cfg.sliding_window),
1956 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1957 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1958 };
1959
1960 Ok(Box::new(cfg))
1961 }
1962}
1963
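/// Serde-facing subset of a Gemma 2 `config.json`; `deserialize` converts it into the
/// internal `models::gemma2::Config`.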
#[derive(Deserialize)]
struct Gemma2BasicConfig {
    attention_bias: bool,
    head_dim: usize,
    hidden_act: Option<Activation>,
    hidden_activation: Option<Activation>,
    hidden_size: usize,
    intermediate_size: usize,
    num_attention_heads: usize,
    num_hidden_layers: usize,
    num_key_value_heads: usize,
    rms_norm_eps: f64,
    rope_theta: f64,
    vocab_size: usize,
    sliding_window: usize,
    attn_logit_softcapping: Option<f64>,
    final_logit_softcapping: Option<f64>,
    query_pre_attn_scalar: usize,

    #[serde(default = "default_max_position_embeddings")]
    max_position_embeddings: usize,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl Gemma2BasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::gemma2::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::gemma2::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            hidden_activation: basic_config.hidden_activation,
            max_position_embeddings: basic_config.max_position_embeddings,
            rms_norm_eps: basic_config.rms_norm_eps,
            rope_theta: basic_config.rope_theta,
            attention_bias: basic_config.attention_bias,
            head_dim: basic_config.head_dim,
            use_flash_attn,
            quantization_config: basic_config.quantization_config,
            sliding_window: basic_config.sliding_window,
            attn_logit_softcapping: basic_config.attn_logit_softcapping,
            final_logit_softcapping: basic_config.final_logit_softcapping,
            query_pre_attn_scalar: basic_config.query_pre_attn_scalar,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

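/// Loader for Gemma 2 text models, with X-LoRA support via `xlora_models::XLoraGemma2`.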
pub struct Gemma2Loader;

impl NormalModelLoader for Gemma2Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::gemma2::Model::new(
            &Gemma2BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraGemma2::new(
            &Gemma2BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(Gemma2BasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for Gemma2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = Gemma2BasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = Gemma2BasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = Gemma2BasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

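/// Serde-facing subset of a Starcoder 2 `config.json`; `deserialize` converts it into the
/// internal `models::starcoder2::Config`.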
#[derive(Deserialize, Debug)]
struct Starcoder2BasicConfig {
    vocab_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    hidden_act: Activation,
    max_position_embeddings: usize,
    norm_epsilon: f64,
    rope_theta: f64,
    use_bias: bool,
    sliding_window: Option<usize>,
    quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl Starcoder2BasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::starcoder2::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::starcoder2::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            max_position_embeddings: basic_config.max_position_embeddings,
            rope_theta: basic_config.rope_theta,
            sliding_window: basic_config.sliding_window,
            use_flash_attn,
            norm_epsilon: basic_config.norm_epsilon,
            use_bias: basic_config.use_bias,
            quantization_config: basic_config.quantization_config,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

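/// Loader for Starcoder 2 models, with X-LoRA support via `xlora_models::XLoraStarcoder2`.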
pub struct Starcoder2Loader;

impl NormalModelLoader for Starcoder2Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::starcoder2::Model::new(
            &Starcoder2BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
            &Starcoder2BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(serde_json::from_str::<Starcoder2BasicConfig>(
            config,
        )?))
    }
}

impl IsqModelLoader for Starcoder2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Starcoder2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size + cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + fc1
                + fc2
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = Starcoder2BasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

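/// Serde-facing subset of a Phi 3.5 MoE `config.json`; `deserialize` converts it into the
/// internal `models::phi3_5_moe::Config`.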
#[derive(Deserialize)]
struct Phi3_5MoEBasicConfig {
    vocab_size: usize,
    hidden_act: Activation,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    rms_norm_eps: f64,
    rope_theta: f64,
    rope_scaling: Option<PhiRopeScalingConfig>,
    max_position_embeddings: usize,
    original_max_position_embeddings: usize,
    sliding_window: Option<usize>,
    quantization_config: Option<QuantizedConfig>,
    lm_head_bias: bool,
    attention_bias: bool,
    num_local_experts: usize,
    router_jitter_noise: f64,
    #[serde(default = "word_emb_default")]
    tie_word_embeddings: bool,
}

impl Phi3_5MoEBasicConfig {
    fn deserialize(slice: &str, use_flash_attn: bool) -> Result<models::phi3_5_moe::Config> {
        let basic_config: Self = serde_json::from_str(slice)?;
        Ok(models::phi3_5_moe::Config {
            vocab_size: basic_config.vocab_size,
            hidden_size: basic_config.hidden_size,
            intermediate_size: basic_config.intermediate_size,
            num_hidden_layers: basic_config.num_hidden_layers,
            num_attention_heads: basic_config.num_attention_heads,
            num_key_value_heads: basic_config.num_key_value_heads,
            hidden_act: basic_config.hidden_act,
            max_position_embeddings: basic_config.max_position_embeddings,
            rope_theta: basic_config.rope_theta,
            rms_norm_eps: basic_config.rms_norm_eps,
            rope_scaling: basic_config.rope_scaling,
            original_max_position_embeddings: basic_config.original_max_position_embeddings,
            use_flash_attn,
            sliding_window: basic_config.sliding_window,
            quantization_config: basic_config.quantization_config,
            lm_head_bias: basic_config.lm_head_bias,
            attention_bias: basic_config.attention_bias,
            num_local_experts: basic_config.num_local_experts,
            router_jitter_noise: basic_config.router_jitter_noise,
            tie_word_embeddings: basic_config.tie_word_embeddings,
        })
    }
}

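/// Loader for Phi 3.5 MoE models; the MoE expert weights get dedicated ISQ regexes
/// (see `isq_layer_regexes_moqe` below).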
pub struct Phi3_5MoELoader;

impl NormalModelLoader for Phi3_5MoELoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(models::phi3_5_moe::Model::new(
            &Phi3_5MoEBasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Ok(Box::new(xlora_models::XLoraPhi3::new(
            &Phi3BasicConfig::deserialize(config, use_flash_attn)?,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        Ok(Box::new(Phi3_5MoEBasicConfig::deserialize(
            config,
            use_flash_attn,
        )?))
    }
}

impl IsqModelLoader for Phi3_5MoELoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }

    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Phi3_5MoELoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let moe_block = {
                let gate = cfg.hidden_size * cfg.num_local_experts;
                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                gate + cfg.num_local_experts * w1
                    + cfg.num_local_experts * w2
                    + cfg.num_local_experts * w3
            };

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + moe_block
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg = Phi3_5MoEBasicConfig::deserialize(config, false)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

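/// Loader for DeepSeek V2 models. The Hugging Face `config.json` deserializes directly into
/// `models::deepseek2::DeepSeekV2Config`, so no `BasicConfig` shim is needed here.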
pub struct DeepSeekV2Loader;

impl NormalModelLoader for DeepSeekV2Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let mut cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
        cfg.use_flash_attn = use_flash_attn;
        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _use_flash_attn: bool,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
}

impl IsqModelLoader for DeepSeekV2Loader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let mut data = vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
        ];
        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
        if cfg.q_lora_rank.is_some() {
            data.extend(vec![
                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
            ]);
        } else {
            data.push(Regex::new(
                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?);
        }
        for layer_idx in 0..cfg.num_hidden_layers {
            if cfg.n_routed_experts.is_some()
                && layer_idx >= cfg.first_k_dense_replace
                && layer_idx % cfg.moe_layer_freq == 0
            {
                for i in 0..cfg.n_routed_experts.unwrap() {
                    data.extend(vec![
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
                        ))?,
                    ]);
                }
                if cfg.n_shared_experts.is_some() {
                    data.extend(vec![
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
                        ))?,
                    ]);
                }
            } else {
                data.extend(vec![
                    Regex::new(&format!(
                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
                    ))?,
                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
                    Regex::new(&format!(
                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
                    ))?,
                ]);
            };
        }
        Ok(data)
    }

    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
        for layer_idx in 0..cfg.num_hidden_layers {
            if cfg.n_routed_experts.is_some()
                && layer_idx >= cfg.first_k_dense_replace
                && layer_idx % cfg.moe_layer_freq == 0
            {
                for i in 0..cfg.n_routed_experts.unwrap() {
                    data.extend(vec![
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
                        ))?,
                    ]);
                }
                if cfg.n_shared_experts.is_some() {
                    data.extend(vec![
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
                        ))?,
                    ]);
                }
            } else {
                data.extend(vec![
                    Regex::new(&format!(
                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
                    ))?,
                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
                    Regex::new(&format!(
                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
                    ))?,
                ]);
            };
        }
        Ok(data)
    }
}

impl DeviceMappedModelLoader for DeepSeekV2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
        let mut per_layer_elems = Vec::new();

        for layer_idx in 0..cfg.num_hidden_layers {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let q_proj = match cfg.q_lora_rank {
                Some(lora_rank) => {
                    let a = cfg.hidden_size * lora_rank;
                    let norm = lora_rank;
                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
                    a + norm + b
                }
                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
            };
            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
                / weight_pack_factor
                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
            let kv_a_layernorm = cfg.kv_lora_rank;
            let kv_b_proj = cfg.kv_lora_rank
                * cfg.num_attention_heads
                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
                / weight_pack_factor;
            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
                / weight_pack_factor
                + bias_if!(cfg.attention_bias, cfg.hidden_size);

            let moe_block = {
                let mut sum = 0;
                if cfg.n_routed_experts.is_some()
                    && layer_idx >= cfg.first_k_dense_replace
                    && layer_idx % cfg.moe_layer_freq == 0
                {
                    let h_size = cfg.hidden_size;
                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
                        * cfg.n_routed_experts.unwrap();
                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
                        * cfg.n_routed_experts.unwrap();
                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
                        * cfg.n_routed_experts.unwrap();
                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
                            / weight_pack_factor;
                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
                            / weight_pack_factor;
                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
                            / weight_pack_factor;
                        gate_proj + up_proj + down_proj
                    } else {
                        0
                    };
                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
                } else {
                    let h_size = cfg.hidden_size;
                    let i_size = cfg.intermediate_size;
                    let gate_proj = h_size * i_size / weight_pack_factor;
                    let up_proj = h_size * i_size / weight_pack_factor;
                    let down_proj = i_size * h_size / weight_pack_factor;
                    sum += gate_proj + up_proj + down_proj;
                }
                sum
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + kv_a_layernorm
                    + kv_a_proj_with_mqa
                    + kv_b_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_attention_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
            v_head_dim: cfg.v_head_dim,
        };

        Ok(Box::new(cfg))
    }
}

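/// Loader for DeepSeek V3 models; mirrors the DeepSeek V2 loader but targets
/// `models::deepseek3::DeepSeekV3Config`.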
pub struct DeepSeekV3Loader;

impl NormalModelLoader for DeepSeekV3Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let mut cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
        cfg.use_flash_attn = use_flash_attn;
        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _use_flash_attn: bool,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
}

impl IsqModelLoader for DeepSeekV3Loader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let mut data = vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
        ];
        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
        if cfg.q_lora_rank.is_some() {
            data.extend(vec![
                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
            ]);
        } else {
            data.push(Regex::new(
                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?);
        }
        for layer_idx in 0..cfg.num_hidden_layers {
            if cfg.n_routed_experts.is_some()
                && layer_idx >= cfg.first_k_dense_replace
                && layer_idx % cfg.moe_layer_freq == 0
            {
                for i in 0..cfg.n_routed_experts.unwrap() {
                    data.extend(vec![
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
                        ))?,
                    ]);
                }
                if cfg.n_shared_experts.is_some() {
                    data.extend(vec![
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
                        ))?,
                    ]);
                }
            } else {
                data.extend(vec![
                    Regex::new(&format!(
                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
                    ))?,
                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
                    Regex::new(&format!(
                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
                    ))?,
                ]);
            };
        }
        Ok(data)
    }

    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
        for layer_idx in 0..cfg.num_hidden_layers {
            if cfg.n_routed_experts.is_some()
                && layer_idx >= cfg.first_k_dense_replace
                && layer_idx % cfg.moe_layer_freq == 0
            {
                for i in 0..cfg.n_routed_experts.unwrap() {
                    data.extend(vec![
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
                        ))?,
                    ]);
                }
                if cfg.n_shared_experts.is_some() {
                    data.extend(vec![
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
                        ))?,
                        Regex::new(&format!(
                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
                        ))?,
                    ]);
                }
            } else {
                data.extend(vec![
                    Regex::new(&format!(
                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
                    ))?,
                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
                    Regex::new(&format!(
                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
                    ))?,
                ]);
            };
        }
        Ok(data)
    }
}

impl DeviceMappedModelLoader for DeepSeekV3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
        let mut per_layer_elems = Vec::new();

        for layer_idx in 0..cfg.num_hidden_layers {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let q_proj = match cfg.q_lora_rank {
                Some(lora_rank) => {
                    let a = cfg.hidden_size * lora_rank;
                    let norm = lora_rank;
                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
                    a + norm + b
                }
                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
            };
            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
                / weight_pack_factor
                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
            let kv_a_layernorm = cfg.kv_lora_rank;
            let kv_b_proj = cfg.kv_lora_rank
                * cfg.num_attention_heads
                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
                / weight_pack_factor;
            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
                / weight_pack_factor
                + bias_if!(cfg.attention_bias, cfg.hidden_size);

            let moe_block = {
                let mut sum = 0;
                if cfg.n_routed_experts.is_some()
                    && layer_idx >= cfg.first_k_dense_replace
                    && layer_idx % cfg.moe_layer_freq == 0
                {
                    let h_size = cfg.hidden_size;
                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
                        * cfg.n_routed_experts.unwrap();
                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
                        * cfg.n_routed_experts.unwrap();
                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
                        * cfg.n_routed_experts.unwrap();
                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
                            / weight_pack_factor;
                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
                            / weight_pack_factor;
                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
                            / weight_pack_factor;
                        gate_proj + up_proj + down_proj
                    } else {
                        0
                    };
                    let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
                } else {
                    let h_size = cfg.hidden_size;
                    let i_size = cfg.intermediate_size;
                    let gate_proj = h_size * i_size / weight_pack_factor;
                    let up_proj = h_size * i_size / weight_pack_factor;
                    let down_proj = i_size * h_size / weight_pack_factor;
                    sum += gate_proj + up_proj + down_proj;
                }
                sum
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + kv_a_layernorm
                    + kv_a_proj_with_mqa
                    + kv_b_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_attention_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
            v_head_dim: cfg.v_head_dim,
        };

        Ok(Box::new(cfg))
    }
}