#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

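//! Decoder-only transformer with parallel attention + MLP residual blocks, partial rotary
//! position embeddings, and optional QK layer normalization (a Phi-2 style architecture),
//! with tensor-parallel, quantization-aware, and paged-attention support.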
use std::{collections::HashMap, sync::Arc};

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, LayerNorm};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};

use crate::{
    amoe::{
        AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainableLayer, MlpLayer,
        MoeMlp,
    },
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{embedding, layer_norm, Activation, CausalMasker, MatMul, RotaryEmbedding, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

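/// Model configuration, typically deserialized from the checkpoint's `config.json`.
/// `num_key_value_heads` falls back to `num_attention_heads` when absent.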
#[derive(Debug, Clone, Deserialize, Default, Serialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: Option<usize>,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) layer_norm_eps: f64,
    pub(crate) rope_theta: f32,
    pub(crate) partial_rotary_factor: f64,
    pub(crate) qk_layernorm: bool,
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

impl Config {
    pub(crate) fn num_key_value_heads(&self) -> usize {
        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
    }

    pub(crate) fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

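/// Feed-forward block: `fc1` -> activation -> `fc2`, sharded column-wise then row-wise for
/// tensor parallelism.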
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
    act: Activation,
    params: Vec<usize>,
}

impl MLP {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let fc1 = ColumnParallelLayer::new(
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            true,
            comm,
            vb.pp("fc1"),
        )?;
        let fc2 = RowParallelLayer::new(
            cfg.intermediate_size,
            cfg.hidden_size,
            &cfg.quantization_config,
            true,
            comm,
            vb.pp("fc2"),
        )?;
        Ok(Self {
            fc1,
            fc2,
            act: cfg.hidden_act,
            params: vec![cfg.hidden_size, cfg.intermediate_size],
        })
    }
}

impl AnyMoeTrainableLayer for MLP {}

impl MlpLayer for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        // Quantized layers may require a specific activation dtype; cast in here and back below.
        if let Some(t) = self.fc1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = MatMul.qmethod_matmul(
            &MatMul.qmethod_matmul(&xs, &*self.fc1)?.apply(&self.act)?,
            &*self.fc2,
        )?;
        if self.fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.fc1, &mut self.fc2]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_fc1 = if let Some(ref delta) = deltas[0] {
            self.fc1.add_delta_w(delta)?
        } else {
            self.fc1.clone()
        };
        let new_fc2 = if let Some(ref delta) = deltas[1] {
            self.fc2.add_delta_w(delta)?
        } else {
            self.fc2.clone()
        };

        Ok(Box::new(Self {
            fc1: new_fc1,
            fc2: new_fc2,
            act: self.act,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.fc1.dtype_and_device()
    }
}

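/// Multi-head self-attention with separate Q/K/V projections, an output `dense` projection,
/// optional per-head Q/K layer norms, and partial rotary embeddings applied in `forward`.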
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    dense: Arc<dyn QuantMethod>,
    q_layernorm: Option<LayerNorm>,
    k_layernorm: Option<LayerNorm>,
    rotary_emb: Arc<RotaryEmbedding>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        rope: Arc<RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads();
        let head_dim = cfg.head_dim();
        let q_proj = ColumnParallelLayer::new(
            cfg.hidden_size,
            num_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads(),
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let dense = RowParallelLayer::new(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            true,
            comm,
            vb.pp("dense"),
        )?;
        let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
            let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
            let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
            (Some(q_layernorm), Some(k_layernorm))
        } else {
            (None, None)
        };
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            dense,
            q_layernorm,
            k_layernorm,
            rotary_emb: rope,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads(),
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

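    /// Projects Q/K/V, applies optional QK layer norm and rotary embeddings, then runs either
    /// paged attention (when KV-cache metadata is supplied) or standard SDPA over the KvCache.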
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_size, seq_len, _n_embd) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let q = match &self.q_layernorm {
            None => q,
            Some(ln) => q.apply(ln)?,
        };
        let k = match &self.k_layernorm {
            None => k,
            Some(ln) => k.apply(ln)?,
        };

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_size, seq_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            // With a single query token the transpose is a no-op, so reshape directly.
            let q = q.reshape((b_size, self.num_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // No paged KV cache was provided; run attention over the full sequence
                    // with dummy metadata. The causal mask must be present in this case.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if mask.is_some() {
            attn_output
                .transpose(1, 2)?
                .reshape((b_size, seq_len, ()))?
        } else {
            attn_output.reshape((b_size, seq_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.dense)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

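/// A single decoder block. Attention and the MLP both consume the same layer-normed input,
/// and their outputs are summed together with the residual (parallel block layout).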
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rotary_emb: Arc<RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            rotary_emb,
            paged_attn,
            comm,
        )?;
        let mlp = MLP::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.input_layernorm)?;
        let attn_outputs =
            self.self_attn
                .forward(&xs, mask, seqlen_offsets, kv_cache, metadata, flash_params)?;
        let feed_forward_hidden_states = self.mlp.forward(&xs)?;
        attn_outputs + feed_forward_hidden_states + residual
    }
}

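/// The full causal language model: token embeddings, a stack of decoder layers, a final
/// layer norm, and the LM head, with device mapping and KV-cache state.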
pub struct Model {
    embed_tokens: Embedding,
    layers: Vec<DecoderLayer>,
    final_layernorm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    cache: EitherCache,
    device: Device,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let final_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            mapper.set_nm_device(vb_m.pp("final_layernorm"), false),
        )?;
        // Build one (partial) rotary embedding per unique device so layers can share them.
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new_partial(
                    cfg.rope_theta,
                    (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_m = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                cfg,
                vb_m.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        // Tied word embeddings are not supported by this model.
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            unreachable!()
        };
        Ok(Self {
            embed_tokens,
            layers,
            final_layernorm,
            lm_head,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: (cfg.num_key_value_heads() / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                sliding_window: None,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

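    /// Runs the full forward pass: embed tokens, apply each decoder layer (moving activations
    /// across devices as mapped), apply the final layer norm, and project to vocabulary logits
    /// for the requested context positions.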
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = input_ids.apply(&self.embed_tokens)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // Keep the causal mask only for the first prompt chunk (or when no paged-attention
        // metadata is present).
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.final_layernorm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

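/// In-situ quantization support: exposes every quantizable linear layer (attention projections,
/// MLP layers, and the LM head) along with the residual, non-quantized tensors.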
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.dense, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("final_layernorm").add(&self.final_layernorm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
        }

        uvb.to_safetensors()
    }
}

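/// `NormalModel` wiring: delegates to `Model::forward` and exposes the cache, device, and
/// configuration metadata. X-LoRA is not supported for this model.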
impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

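/// AnyMoE support: each layer's MLP can be replaced by a `MoeMlp` that routes between the
/// original MLP and additional experts, which are either fine-tuned MLPs or LoRA-derived deltas.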
impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        // An empty layer list means every decoder layer receives experts.
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(MLP::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let fc1_delta = if target_modules.contains(&"fc1".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "fc1"
                            ))
                        } else {
                            None
                        };
                        let fc2_delta = if target_modules.contains(&"fc2".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "fc2"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![fc1_delta, fc2_delta])?,
                        );
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}