#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{
        AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainableLayer, MlpLayer,
        MoeMlp,
    },
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub sliding_window: Option<usize>,
    pub original_max_position_embeddings: usize,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
    pub partial_rotary_factor: Option<f64>,
}

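// The shared Phi rotary-embedding implementation consumes a `PhiRopeConfig`; this
// conversion forwards the RoPE-related fields (scaling, base, head dim, partial factor).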
impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: val.partial_rotary_factor,
        }
    }
}

impl Config {
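    /// Per-attention-head dimension, computed as `hidden_size / num_attention_heads`
    /// (e.g. an assumed 3072 / 32 = 96).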
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

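/// Self-attention block with a fused QKV projection (`qkv_proj`) and an output
/// projection (`o_proj`). It runs either the eager SDPA path backed by a per-layer KV
/// cache or the PagedAttention path when paged-attention metadata is supplied.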
struct Attention {
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        let qkv_proj = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            op_size,
            &cfg.quantization_config,
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
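        // The fused QKV output is laid out as [query | key | value] along the last
        // dimension; slice the three parts apart before reshaping to per-head form.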
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

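/// Gated feed-forward block with a fused `gate_up_proj` (gate and up projections packed
/// along the output dimension) followed by `down_proj`.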
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    i_size: usize,
    params: Vec<usize>,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        let gate_up_proj = mistralrs_quant::linear_no_bias(
            hidden_size,
            2 * i_size,
            &cfg.quantization_config,
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = mistralrs_quant::linear_no_bias(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
            params: vec![hidden_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
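        // gate_up_proj packs [gate | up]; split the halves, apply the activation to the
        // gate, and scale the up states with it before the down projection.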
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_gate_up = if let Some(ref delta) = deltas[0] {
            self.gate_up_proj.add_delta_w(delta)?
        } else {
            self.gate_up_proj.clone()
        };
        let new_down = if let Some(ref delta) = deltas[1] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_up_proj: new_gate_up,
            down_proj: new_down,
            act_fn: self.act_fn,
            i_size: self.i_size,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_up_proj.dtype_and_device()
    }
}

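/// Pre-norm transformer block: RMSNorm -> attention -> residual add, then
/// RMSNorm -> MLP -> residual add.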
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
        )?;
        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

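/// Full causal language model: token embeddings, the stack of decoder layers, a final
/// RMSNorm, and the LM head (optionally tied to the embedding matrix).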
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let vb_l = vb_m.pp("layers");
        let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )
        })?;
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

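    /// Runs the forward pass over `input_ids`, returning logits for the positions
    /// selected by `context_lens`. The sliding-window causal mask is built from either
    /// the passed sequence-length offsets (when paged-attention metadata is present) or
    /// the normal KV cache lengths.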
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
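        // With paged-attention metadata present, keep the mask only for the first prompt
        // chunk; without metadata the mask is always kept.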
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

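// In-situ quantization (ISQ) support: expose the quantizable linear layers and the
// residual (non-quantized) tensors that must be preserved when re-serializing the model.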
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

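    // Imatrix names mirror the ordering of `get_layers`: `None` for the lm_head, then
    // GGUF-style tensor names for each layer's fused QKV, attention output, fused
    // gate/up, and down projections.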
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        let mut names = Vec::new();
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_qkv.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

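// AnyMoE support: each selected layer's MLP can be replaced by a `MoeMlp` whose experts
// are either freshly loaded fine-tuned MLPs or LoRA-delta variants of the existing MLP.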
impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_up_proj_delta =
                            if target_modules.contains(&"gate_up_proj".to_string()) {
                                Some(get_delta_from_lora_ab!(
                                    vb_mlp,
                                    rank,
                                    alpha,
                                    (hidden_size, 2 * intermediate_size),
                                    "gate_up_proj"
                                ))
                            } else {
                                None
                            };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
                        );
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}