1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{DType, Device, Result, Tensor};
10use candle_nn::{Embedding, LayerNorm};
11use mistralrs_quant::{
12 ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
13 ShardedVarBuilder,
14};
15use serde::{Deserialize, Serialize};
16
17use crate::{
18 amoe::{
19 AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainableLayer, MlpLayer,
20 MoeMlp,
21 },
22 attention::SdpaParams,
23 device_map::DeviceMapper,
24 get_delta_from_lora_ab,
25 layers::{embedding, layer_norm, Activation, CausalMasker, MatMul, RotaryEmbedding, Sdpa},
26 layers_masker::PastKvLenCache,
27 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
28 pipeline::{
29 extract_logits,
30 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
31 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
32 },
33 serde_default_fn,
34 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
35};
36
// Default used by serde for `tie_word_embeddings` when the key is absent.
serde_default_fn!(bool, word_emb_default, false);

/// Model configuration as deserialized from the checkpoint's config JSON.
/// Field names mirror the upstream (HF-style) config keys.
#[derive(Debug, Clone, Deserialize, Default, Serialize)]
pub struct Config {
    // Token vocabulary size (rows of embed_tokens / lm_head).
    pub(crate) vocab_size: usize,
    // Model (embedding) dimension.
    pub(crate) hidden_size: usize,
    // Inner MLP dimension (fc1 output / fc2 input).
    pub(crate) intermediate_size: usize,
    // Number of decoder layers.
    pub(crate) num_hidden_layers: usize,
    // Number of attention (query) heads.
    pub(crate) num_attention_heads: usize,
    // KV-head count; `None` means "same as num_attention_heads" (no GQA).
    pub(crate) num_key_value_heads: Option<usize>,
    // Activation used inside the MLP.
    pub(crate) hidden_act: Activation,
    // Maximum sequence length (sizes the KV cache and rotary tables).
    pub(crate) max_position_embeddings: usize,
    // Epsilon shared by all LayerNorm layers.
    pub(crate) layer_norm_eps: f64,
    // RoPE base frequency.
    pub(crate) rope_theta: f32,
    // Fraction of each head dim that rotary embeddings are applied to.
    pub(crate) partial_rotary_factor: f64,
    // Whether queries/keys get a per-head LayerNorm.
    pub(crate) qk_layernorm: bool,
    // Whether flash attention is enabled for SDPA.
    pub(crate) use_flash_attn: bool,
    // Optional quantization settings for the linear layers.
    pub(crate) quantization_config: Option<QuantizedConfig>,
    // Whether lm_head would share weights with the embedding (unsupported
    // here — `Model::new` requires this to be false).
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}
59
60impl Config {
61 pub(crate) fn num_key_value_heads(&self) -> usize {
62 self.num_key_value_heads.unwrap_or(self.num_attention_heads)
63 }
64
65 pub(crate) fn head_dim(&self) -> usize {
66 self.hidden_size / self.num_attention_heads
67 }
68}
69
/// Phi-style feed-forward block: fc2(act(fc1(x))).
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    // Up-projection (hidden_size -> intermediate_size), column-parallel.
    fc1: Arc<dyn QuantMethod>,
    // Down-projection (intermediate_size -> hidden_size), row-parallel.
    fc2: Arc<dyn QuantMethod>,
    // Activation applied between the two projections.
    act: Activation,
    // [hidden_size, intermediate_size]; exposed via `MlpLayer::get_params`.
    params: Vec<usize>,
}
78
79impl MLP {
80 fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
81 let fc1 = ColumnParallelLayer::new(
82 cfg.hidden_size,
83 cfg.intermediate_size,
84 &cfg.quantization_config,
85 true,
86 comm,
87 vb.pp("fc1"),
88 )?;
89 let fc2 = RowParallelLayer::new(
90 cfg.intermediate_size,
91 cfg.hidden_size,
92 &cfg.quantization_config,
93 true,
94 comm,
95 vb.pp("fc2"),
96 )?;
97 Ok(Self {
98 fc1,
99 fc2,
100 act: cfg.hidden_act,
103 params: vec![cfg.hidden_size, cfg.intermediate_size],
104 })
105 }
106}
107
// Marker impl: an MLP can act as a trainable expert inside AnyMoE.
impl AnyMoeTrainableLayer for MLP {}
109
impl MlpLayer for MLP {
    /// fc2(act(fc1(xs))), casting to the quantized activation dtype when the
    /// layers require one and restoring the input dtype afterwards.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = MatMul.qmethod_matmul(
            &MatMul.qmethod_matmul(&xs, &*self.fc1)?.apply(&self.act)?,
            &*self.fc2,
        )?;
        if self.fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    // Linear layers eligible for in-situ quantization (ISQ).
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.fc1, &mut self.fc2]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        // The trait method shadows the derived impl, so call it explicitly.
        Box::new(Clone::clone(self))
    }
    // Returns [hidden_size, intermediate_size].
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act
    }
    /// Returns a copy of this MLP with optional weight deltas merged into the
    /// projections; `deltas` is ordered [fc1, fc2], `None` meaning "unchanged".
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_fc1 = if let Some(ref delta) = deltas[0] {
            self.fc1.add_delta_w(delta)?
        } else {
            self.fc1.clone()
        };
        let new_fc2 = if let Some(ref delta) = deltas[1] {
            self.fc2.add_delta_w(delta)?
        } else {
            self.fc2.clone()
        };

        Ok(Box::new(Self {
            fc1: new_fc1,
            fc2: new_fc2,
            act: self.act,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.fc1.dtype_and_device()
    }
}
163
/// Multi-head self-attention with partial rotary embeddings, optional
/// per-head Q/K LayerNorm, and an optional paged-attention backend.
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    // Output projection ("dense" in the checkpoint naming).
    dense: Arc<dyn QuantMethod>,
    // Present only when `Config::qk_layernorm` is set.
    q_layernorm: Option<LayerNorm>,
    k_layernorm: Option<LayerNorm>,
    rotary_emb: Arc<RotaryEmbedding>,
    // Head counts for THIS rank (already divided by the TP world size).
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    // `Some` when the paged-attention backend is selected.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
178
179impl Attention {
180 fn new(
181 cfg: &Config,
182 vb: ShardedVarBuilder,
183 rope: Arc<RotaryEmbedding>,
184 paged_attn: Option<PagedAttention>,
185 comm: &Arc<mistralrs_quant::Comm>,
186 ) -> Result<Self> {
187 let num_heads = cfg.num_attention_heads;
188 let num_kv_heads = cfg.num_key_value_heads();
189 let head_dim = cfg.head_dim();
190 let q_proj = ColumnParallelLayer::new(
191 cfg.hidden_size,
192 num_heads * head_dim,
193 &cfg.quantization_config,
194 true,
195 comm,
196 vb.pp("q_proj"),
197 )?;
198 let kv_shard = mistralrs_quant::compute_kv_shard(
199 cfg.num_attention_heads,
200 cfg.hidden_size / cfg.num_attention_heads,
201 comm,
202 );
203 let k_proj = ColumnParallelLayer::new_with_shard(
204 cfg.hidden_size,
205 num_kv_heads * head_dim,
206 &cfg.quantization_config,
207 true,
208 comm,
209 kv_shard,
210 vb.pp("k_proj"),
211 )?;
212 let v_proj = ColumnParallelLayer::new_with_shard(
213 cfg.hidden_size,
214 num_kv_heads * head_dim,
215 &cfg.quantization_config,
216 true,
217 comm,
218 kv_shard,
219 vb.pp("v_proj"),
220 )?;
221 let dense = RowParallelLayer::new(
222 num_heads * head_dim,
223 cfg.hidden_size,
224 &cfg.quantization_config,
225 true,
226 comm,
227 vb.pp("dense"),
228 )?;
229 let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
230 let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
231 let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
232 (Some(q_layernorm), Some(k_layernorm))
233 } else {
234 (None, None)
235 };
236 Ok(Self {
237 q_proj,
238 k_proj,
239 v_proj,
240 dense,
241 q_layernorm,
242 k_layernorm,
243 rotary_emb: rope,
244 num_heads: num_heads / comm.world_size(),
245 num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
246 head_dim,
247 paged_attn,
248 sdpa_params: SdpaParams {
249 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
250 cfg.num_attention_heads,
251 cfg.num_attention_heads,
252 comm,
253 ),
254 use_flash_attn: cfg.use_flash_attn,
255 softcap: None,
256 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
257 sliding_window: None,
258 },
259 })
260 }
261
262 #[allow(clippy::too_many_arguments)]
263 fn forward(
264 &self,
265 xs: &Tensor,
266 mask: Option<&Tensor>,
267 seqlen_offsets: &[usize],
268 kv_cache: &mut KvCache,
269 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
270 flash_params: &FlashParams,
271 ) -> Result<Tensor> {
272 let (b_size, seq_len, _n_embd) = xs.dims3()?;
273
274 let original_dtype = xs.dtype();
275 let mut xs = xs.clone();
276 if let Some(t) = self.q_proj.quantized_act_type() {
277 xs = xs.to_dtype(t)?;
278 }
279 let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
280 let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
281 let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
282 if self.q_proj.quantized_act_type().is_some() {
283 q = q.to_dtype(original_dtype)?;
284 k = k.to_dtype(original_dtype)?;
285 v = v.to_dtype(original_dtype)?;
286 }
287
288 let q = match &self.q_layernorm {
289 None => q,
290 Some(ln) => q.apply(ln)?,
291 };
292 let k = match &self.k_layernorm {
293 None => k,
294 Some(ln) => k.apply(ln)?,
295 };
296
297 let (q, k, v) = if seq_len != 1 {
298 let q = q
299 .reshape((b_size, seq_len, self.num_heads, self.head_dim))?
300 .transpose(1, 2)?;
301 let k = k
302 .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
303 .transpose(1, 2)?;
304 let v = v
305 .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
306 .transpose(1, 2)?;
307 (q, k, v)
308 } else {
309 let q = q.reshape((b_size, self.num_heads, seq_len, self.head_dim))?;
310 let k = k.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
311 let v = v.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
312 (q, k, v)
313 };
314
315 let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
316
317 let mut attn_output = match &self.paged_attn {
318 Some(paged_attn) => match metadata {
319 Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
320 &q,
321 &k,
322 &v,
323 mask,
324 Some(key_cache),
325 Some(value_cache),
326 input_metadata,
327 &self.sdpa_params,
328 Some(flash_params),
329 )?,
330 None => {
331 let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
334 assert!(mask.is_some());
336 paged_attn.forward(
337 &q,
338 &k,
339 &v,
340 mask,
341 None,
342 None,
343 &input_metadata,
344 &self.sdpa_params,
345 Some(flash_params),
346 )?
347 }
348 },
349 None => {
350 let (k, v) = kv_cache.append(&k, &v)?;
351
352 Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
353 }
354 };
355
356 if let Some(t) = self.q_proj.quantized_act_type() {
357 attn_output = attn_output.to_dtype(t)?;
358 }
359 attn_output = if mask.is_some() {
360 attn_output
361 .transpose(1, 2)?
362 .reshape((b_size, seq_len, ()))?
363 } else {
364 attn_output.reshape((b_size, seq_len, ()))?
365 };
366 let mut res = MatMul.qmethod_matmul(&attn_output, &*self.dense)?;
367 if self.q_proj.quantized_act_type().is_some() {
368 res = res.to_dtype(original_dtype)?;
369 }
370 Ok(res)
371 }
372}
373
/// One decoder layer: a single input LayerNorm feeding both self-attention
/// and the MLP in parallel (their outputs are summed with the residual).
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    // Shared pre-norm for both the attention and MLP branches.
    input_layernorm: LayerNorm,
}
379
380impl DecoderLayer {
381 #[allow(clippy::too_many_arguments)]
382 fn new(
383 cfg: &Config,
384 vb: ShardedVarBuilder,
385 mapper: &dyn DeviceMapper,
386 layer_idx: usize,
387 loading_isq: bool,
388 rotary_emb: Arc<RotaryEmbedding>,
389 paged_attn: Option<PagedAttention>,
390 comm: &Arc<mistralrs_quant::Comm>,
391 ) -> Result<Self> {
392 let self_attn = Attention::new(
393 cfg,
394 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
395 rotary_emb,
396 paged_attn,
397 comm,
398 )?;
399 let mlp = MLP::new(
400 cfg,
401 mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
402 comm,
403 )?;
404 let input_layernorm = layer_norm(
405 cfg.hidden_size,
406 cfg.layer_norm_eps,
407 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
408 )?;
409 Ok(Self {
410 self_attn,
411 mlp: Box::new(mlp),
412 input_layernorm,
413 })
414 }
415
416 #[allow(clippy::too_many_arguments)]
417 fn forward(
418 &self,
419 xs: &Tensor,
420 mask: Option<&Tensor>,
421 seqlen_offsets: &[usize],
422 kv_cache: &mut KvCache,
423 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
424 flash_params: &FlashParams,
425 ) -> Result<Tensor> {
426 let residual = xs;
427 let xs = xs.apply(&self.input_layernorm)?;
428 let attn_outputs =
429 self.self_attn
430 .forward(&xs, mask, seqlen_offsets, kv_cache, metadata, flash_params)?;
431 let feed_forward_hidden_states = self.mlp.forward(&xs)?;
432 attn_outputs + feed_forward_hidden_states + residual
433 }
434}
435
/// The full decoder-only model: embeddings, decoder stack, final LayerNorm,
/// and LM head, together with its KV cache and device-mapping state.
pub struct Model {
    embed_tokens: Embedding,
    layers: Vec<DecoderLayer>,
    final_layernorm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    // Normal (non-paged) KV cache; see `Model::new`.
    cache: EitherCache,
    // Device logits are gathered onto before the LM head.
    device: Device,
    max_seq_len: usize,
    // Maps per-layer activations/weights across devices and TP ranks.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Metadata consumed by the pipeline (per-rank head counts, etc.).
    cfg: ModelConfigMetadata,
}
447
impl Model {
    /// Loads the whole model from `vb`, mapping each decoder layer onto the
    /// device chosen by `normal_loading_metadata.mapper`.
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let final_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            mapper.set_nm_device(vb_m.pp("final_layernorm"), false),
        )?;
        // Build one rotary embedding per distinct device, shared by all
        // layers placed there. Only `partial_rotary_factor` of each head dim
        // is rotated (partial RoPE).
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new_partial(
                    cfg.rope_theta,
                    (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_m = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                cfg,
                vb_m.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        // Tied word embeddings are not supported for this architecture; the
        // checkpoint must provide a separate lm_head.
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            unreachable!()
        };
        Ok(Self {
            embed_tokens,
            layers,
            final_layernorm,
            lm_head,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Per-rank head counts under tensor parallelism; keep at
                // least one KV head per rank.
                num_kv_heads: (cfg.num_key_value_heads() / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                sliding_window: None,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    /// Runs the decoder stack over `input_ids` and returns logits for the
    /// positions selected by `context_lens`.
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = input_ids.apply(&self.embed_tokens)?;
        let cache = &mut self.cache.normal().0;
        // With paged attention the past length comes from `seqlen_offsets`;
        // otherwise it is derived from the normal KV cache.
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // Only the first chunk of a chunked prompt keeps an explicit mask.
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to layer i's device before running it.
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.final_layernorm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}
613
614impl IsqModel for Model {
615 fn get_layers(
616 &mut self,
617 ) -> (
618 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
619 &dyn DeviceMapper,
620 ) {
621 let mut tensors = Vec::new();
622 tensors.push((&mut self.lm_head, None));
623 for (i, layer) in self.layers.iter_mut().enumerate() {
624 tensors.push((&mut layer.self_attn.q_proj, Some(i)));
625 tensors.push((&mut layer.self_attn.k_proj, Some(i)));
626 tensors.push((&mut layer.self_attn.v_proj, Some(i)));
627 tensors.push((&mut layer.self_attn.dense, Some(i)));
628 tensors.extend(
629 layer
630 .mlp
631 .get_isq_layers()
632 .into_iter()
633 .map(|m| (m, Some(i)))
634 .collect::<Vec<_>>(),
635 );
636 }
637 (tensors, &*self.mapper)
638 }
639
640 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
641 let uvb = UnVarBuilder::new();
642
643 let uvb_m = uvb.pp("model");
644 uvb_m.pp("embed_tokens").add(&self.embed_tokens);
645 uvb_m.pp("norm").add(&self.final_layernorm);
646
647 for (layer_idx, layer) in self.layers.iter().enumerate() {
648 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
649 uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
650 }
651
652 uvb.to_safetensors()
653 }
654}
655
impl NormalModel for Model {
    /// Delegates to the inherent `Model::forward`; `_position_ids` is unused
    /// because positions are derived from `seqlen_offsets`.
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    // X-LoRA is not supported for this model (see `is_xlora`).
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}
708
impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    /// Replaces the MLP of each selected layer with a `MoeMlp` gating over
    /// the original MLP plus one expert per additional var-builder.
    ///
    /// `layers` selects which decoder layers become MoE (empty = all);
    /// experts are either freshly loaded MLPs (`FineTuned`) or LoRA deltas
    /// applied to the existing MLP (`LoraAdapter`).
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        // An empty selection means "make every layer MoE".
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        // One expert per additional var-builder, for each selected layer.
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                // get_params() is [hidden_size, intermediate_size].
                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        // Load a complete fresh MLP as the expert, on the same
                        // dtype/device as the layer's current MLP.
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(MLP::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        // Build the expert by merging LoRA deltas (for the
                        // targeted projections) into the existing MLP.
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let fc1_delta = if target_modules.contains(&"fc1".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "fc1"
                            ))
                        } else {
                            None
                        };
                        let fc2_delta = if target_modules.contains(&"fc2".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "fc2"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![fc1_delta, fc2_delta])?,
                        );
                    }
                }
            }
        }
        // Swap each selected layer's MLP for a gated mixture whose expert 0
        // is the original MLP.
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}