#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::too_many_arguments
)]

use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantMethodConfig, RowParallelLayer, ShardedVarBuilder,
    UnquantLinear,
};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, linear_no_bias as linear, Activation, CausalMasker, MatMul, RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    models::llama::Config,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    AnyMoeConfig, AnyMoeExpertType,
};

use super::{LLaVALLM, OrdinaryRoPE};

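/// Multi-head causal self-attention for the LLaVA Llama text backbone: quantizable
/// Q/K/V/O projections with tensor-parallel sharding, rotary position embeddings, and
/// either eager SDPA with a KV cache or paged attention.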
struct CausalSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
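    /// Projects `x` to Q/K/V, applies RoPE to Q and K, runs paged or eager attention
    /// (updating this layer's KV cache in the eager path), and applies the output
    /// projection. Activations are cast to the quantized activation dtype around the
    /// matmuls when the projections are quantized.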
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut crate::pipeline::LayerCaches,
        rope_parameter: (&Tensor, &Tensor),
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&x, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&x, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&x, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let mut q = q
            .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let mut k = k
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        q = OrdinaryRoPE::forward(&q, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
        k = OrdinaryRoPE::forward(&k, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
        let v = v
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;

        let mut y = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // No paged-attention metadata was supplied for this call, so run
                    // against dummy metadata without a KV cache; the causal mask must
                    // be present in this path.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask.clone().as_ref(),
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) =
                    crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v, false)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        y = if attention_mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&y, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

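    /// Builds the attention projections from `vb`, sharding Q/K/V column-wise and the
    /// output projection row-wise across `comm`, and records the per-rank head counts
    /// and SDPA parameters (softmax scale of 1/sqrt(head_dim), no sliding window).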
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = ColumnParallelLayer::new(
            size_in,
            size_q,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            size_q,
            size_in,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            num_key_value_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            max_seq_len: cfg.max_position_embeddings,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}
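/// SwiGLU feed-forward block: gate (`c_fc1`) and up (`c_fc2`) projections followed by a
/// down projection (`c_proj`), with `params` storing `[hidden_size, intermediate_size]`.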
#[derive(Clone)]
struct Mlp {
    c_fc1: Arc<dyn QuantMethod>,
    c_fc2: Arc<dyn QuantMethod>,
    c_proj: Arc<dyn QuantMethod>,
    params: Vec<usize>,
}

impl Mlp {
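    /// Loads the gate/up/down projections from `gate_proj`, `up_proj`, and `down_proj`,
    /// sharding gate/up column-wise and down row-wise across `comm`.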
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let h_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let c_fc1 = ColumnParallelLayer::new(
            h_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("gate_proj"),
        )?;
        let c_fc2 = ColumnParallelLayer::new(
            h_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("up_proj"),
        )?;
        let c_proj = RowParallelLayer::new(
            i_size,
            h_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("down_proj"),
        )?;
        Ok(Self {
            c_fc1,
            c_fc2,
            c_proj,
            params: vec![h_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
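    /// Computes `down(silu(gate(x)) * up(x))`, casting to the quantized activation
    /// dtype around the matmuls when the projections are quantized.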
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.c_fc1.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let x = (candle_nn::ops::silu(&MatMul.qmethod_matmul(&x, &*self.c_fc1)?)?
            * MatMul.qmethod_matmul(&x, &*self.c_fc2)?)?;
        let mut res = MatMul.qmethod_matmul(&x, &*self.c_proj)?;
        if self.c_fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.c_fc1, &mut self.c_fc2, &mut self.c_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        Activation::Silu
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        // Deltas are ordered to match `get_isq_layers`: c_fc1 (gate), c_fc2 (up), c_proj (down).
        let new_c_fc1 = if let Some(ref delta) = deltas[0] {
            self.c_fc1.add_delta_w(delta)?
        } else {
            self.c_fc1.clone()
        };
        let new_c_fc2 = if let Some(ref delta) = deltas[1] {
            self.c_fc2.add_delta_w(delta)?
        } else {
            self.c_fc2.clone()
        };
        let new_c_proj = if let Some(ref delta) = deltas[2] {
            self.c_proj.add_delta_w(delta)?
        } else {
            self.c_proj.clone()
        };

        Ok(Box::new(Self {
            c_fc1: new_c_fc1,
            c_fc2: new_c_fc2,
            c_proj: new_c_proj,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.c_fc1.dtype_and_device()
    }
}

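/// A single decoder layer: pre-attention RMSNorm, self-attention, post-attention
/// RMSNorm, and MLP, with residual connections around both sublayers. The RoPE
/// cos/sin tables are stored per layer so they live on the layer's mapped device.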
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Box<dyn MlpLayer>,
    rope_parameter: (Tensor, Tensor),
}

impl Block {
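    /// Runs the pre-norm transformer layer: `x + attn(rms_1(x))`, then
    /// `x + mlp(rms_2(x))`.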
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut crate::pipeline::LayerCaches,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let mut x = self.rms_1.forward(x)?;
        x = (self.attn.forward(
            &x,
            attention_mask,
            seqlen_offsets,
            block_idx,
            kv_cache,
            (&self.rope_parameter.0, &self.rope_parameter.1),
            metadata,
            flash_params,
        )? + residual)?;
        let residual = &x;
        x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        rope_parameter: (Tensor, Tensor),
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            cfg,
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::load(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg,
            comm,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp: Box::new(mlp),
            rope_parameter,
        })
    }
}

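/// The Llama text backbone used by LLaVA: token embeddings, a stack of decoder blocks,
/// a final RMSNorm, and the LM head, together with the KV cache, device mapper, and
/// model metadata used by the pipeline.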
pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    kv_cache: crate::pipeline::EitherCache,
    device: Device,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Llama {
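    /// Embeds `input_ids` with the token embedding table and delegates to
    /// `forward_input_embed`.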
    pub fn forward_input(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let x = self.wte.forward(input_ids)?;
        self.forward_input_embed(
            input_ids,
            x,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }

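    /// Loads the full model from `vb`: embeddings, per-layer blocks (with per-layer
    /// RoPE tables and optional paged attention), final norm, and LM head, using the
    /// device mapper from `normal_loading_metadata` to place each layer.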
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
        )?;
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("model.norm"), false),
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        let blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let vb_m = vb.pp(format!("model.layers.{i}"));
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rope_parameters = OrdinaryRoPE::create_parameters(
                head_dim,
                cfg.max_position_embeddings,
                cfg.rope_theta,
                vb_m.dtype(),
                device,
            )
            .unwrap();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(i).unwrap();
            Block::load(
                vb_m,
                cfg,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                paged_attn,
                rope_parameters,
                &comm,
            )
            .expect("Failed to load block.")
        })
        .collect();

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lm_head))?),
            kv_cache: crate::pipeline::EitherCache::Full(crate::pipeline::Cache::new(
                cfg.num_hidden_layers,
                false,
            )),
            device: normal_loading_metadata.real_device,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }
}

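// ISQ support: expose the LM head and each layer's attention and MLP projections as
// quantizable tensors, indexed by layer for the device mapper.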
impl IsqModel for Llama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((&mut layer.attn.q_proj, Some(i)));
            tensors.push((&mut layer.attn.k_proj, Some(i)));
            tensors.push((&mut layer.attn.v_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        Vec::new()
    }
}

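// LLaVA integration: `embed` returns the token embeddings for `input_ids`, and
// `forward_input_embed` runs the decoder stack on caller-provided embeddings (e.g.
// text embeddings merged with image features by the LLaVA pipeline).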
impl LLaVALLM for Llama {
    fn embed(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.wte.forward(input_ids)
    }
    fn forward_input_embed(
        &self,
        input_ids: &Tensor,
        input_embed: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = input_embed;
        let mut cache = self.kv_cache.full().lock();
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            x.dtype(),
            self.blocks[0].attn.num_attention_heads,
        )?;
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                block_idx,
                &mut cache,
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), *metadata)),
                flash_params,
            )?;
        }
        x = x.to_device(&self.device)?;
        x = self.ln_f.forward(&x)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let xs = MatMul.qmethod_matmul(&x, &*self.lm_head)?;
        extract_logits(&xs, context_lens)
    }
}

impl NormalModel for Llama {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_input(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &crate::pipeline::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

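// AnyMoE support: `create_anymoe_layers` swaps each selected layer's MLP for a `MoeMlp`
// whose experts are either freshly loaded fine-tuned MLPs or copies of the base MLP
// with LoRA deltas applied to c_fc1/c_fc2/c_proj.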
impl AnyMoeBaseModelMixin for Llama {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.blocks {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.blocks {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.blocks.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.blocks[layer].mlp.get_params()[1];
                let hidden_size = self.blocks[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.blocks[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::load(
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            &Config {
                                intermediate_size: self.blocks[layer].mlp.get_params()[1],
                                hidden_size: self.blocks[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let c_fc1_delta = if target_modules.contains(&"c_fc1".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc1"
                            ))
                        } else {
                            None
                        };
                        let c_fc2_delta = if target_modules.contains(&"c_fc2".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc2"
                            ))
                        } else {
                            None
                        };
                        let c_proj_delta = if target_modules.contains(&"c_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "c_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.blocks[layer].mlp.new_added_delta(vec![
                            c_fc1_delta,
                            c_fc2_delta,
                            c_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.blocks[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.blocks[layer].mlp.dtype_device();
            self.blocks[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}