#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{LayerNorm, Linear};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantMethodConfig, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder, UnquantLinear,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{embedding, layer_norm, Activation, CausalMasker, MatMul, RotaryEmbedding, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
    AnyMoeConfig, AnyMoeExpertType,
};

serde_default_fn!(bool, word_emb_default, false);

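/// Model configuration, deserialized from the checkpoint's `config.json`.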
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) norm_epsilon: f64,
    pub(crate) rope_theta: f64,
    pub(crate) use_bias: bool,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    #[allow(dead_code)]
    pub(crate) tie_word_embeddings: bool,
}

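/// Feed-forward block: `c_fc` (up-projection), activation, then `c_proj` (down-projection).
/// The projections are column- and row-parallel so they shard across the tensor-parallel group.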
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    c_fc: Arc<dyn QuantMethod>,
    c_proj: Arc<dyn QuantMethod>,
    act: Activation,
    params: Vec<usize>,
}

impl MLP {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
        let c_fc = ColumnParallelLayer::new(
            h_size,
            i_size,
            &cfg.quantization_config,
            cfg.use_bias,
            comm,
            vb.pp("c_fc"),
        )?;
        let c_proj = RowParallelLayer::new(
            i_size,
            h_size,
            &cfg.quantization_config,
            cfg.use_bias,
            comm,
            vb.pp("c_proj"),
        )?;
        Ok(Self {
            c_fc,
            c_proj,
            act: cfg.hidden_act,
            params: vec![h_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for MLP {}

impl MlpLayer for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.c_fc.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = MatMul.qmethod_matmul(
            &MatMul.qmethod_matmul(&xs, &*self.c_fc)?.apply(&self.act)?,
            &*self.c_proj,
        )?;
        if self.c_fc.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.c_fc, &mut self.c_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act
    }
    // Deltas are expected in the order [c_fc, c_proj].
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_c_fc = if let Some(ref delta) = deltas[0] {
            self.c_fc.add_delta_w(delta)?
        } else {
            self.c_fc.clone()
        };
        let new_c_proj = if let Some(ref delta) = deltas[1] {
            self.c_proj.add_delta_w(delta)?
        } else {
            self.c_proj.clone()
        };

        Ok(Box::new(Self {
            c_fc: new_c_fc,
            c_proj: new_c_proj,
            act: self.act,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.c_fc.dtype_and_device()
    }
}

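/// Self-attention with rotary position embeddings and grouped-query KV heads, backed either by
/// PagedAttention (when available) or by a standard SDPA kernel over the per-layer KV cache.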
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let b = cfg.use_bias;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            b,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            b,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            b,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            b,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

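    /// Projects the input to Q/K/V, applies RoPE, runs the attention kernel against the KV
    /// cache, and maps the result back through `o_proj`.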
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // No real cache metadata was provided: fall back to dummy metadata and
                    // rely on the attention mask rather than the paged KV cache.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

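/// A single transformer block: pre-norm self-attention and pre-norm MLP, each wrapped in a
/// residual connection.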
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MLP::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

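/// The full causal LM: token embeddings, the decoder stack, a final layer norm, and an LM head
/// that reuses (ties) the embedding weights.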
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
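    /// Loads all weights from `vb`, building one rotary embedding per target device and
    /// constructing each decoder layer on the device chosen by the device mapper.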
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            layers.push(DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?)
        }
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = mapper.cast_nm_device(
            embed_tokens.embeddings(),
            normal_loading_metadata.loading_isq,
        )?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                Linear::new(lm_head, None),
            ))?),
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

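    /// Embeds `input_ids`, builds the (optionally sliding-window) causal mask, runs every
    /// decoder layer with its KV cache, and returns the logits selected by `context_lens`.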
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;

        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let mut xs = xs.to_device(&self.device)?.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

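// AnyMoE support: `create_anymoe_layers` swaps selected layers' MLPs for a gated `MoeMlp` whose
// experts are either freshly loaded MLPs or LoRA-derived deltas on top of the original weights.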
impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(MLP::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let c_fc_delta = if target_modules.contains(&"c_fc".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc"
                            ))
                        } else {
                            None
                        };
                        let c_proj_delta = if target_modules.contains(&"c_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "c_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![c_fc_delta, c_proj_delta])?,
                        );
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}