#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, Activation, CausalMasker, Llama3RopeConfig, Llama3RotaryEmbedding, MatMul, Mlp,
        RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);
serde_default_fn!(bool, use_flash_attn_default, false);

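/// Llama model configuration, deserialized with `serde`; field names mirror the
/// upstream Hugging Face `config.json` keys.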
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Config {
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    #[serde(default = "use_flash_attn_default")]
    pub use_flash_attn: bool,
    pub rms_norm_eps: f64,
    pub rope_theta: f32,
    pub max_position_embeddings: usize,
    pub rope_scaling: Option<Llama3RopeConfig>,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}

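/// Causal self-attention with rotary position embeddings. The Q/K/V/output
/// projections are quantization-aware (`QuantMethod`) and may be sharded across
/// tensor-parallel ranks; attention runs through either `PagedAttention` or the
/// fallback `Sdpa` kernel.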
struct CausalSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
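    /// Attention for one decoder layer: project to Q/K/V, apply RoPE, attend via the
    /// paged-attention or SDPA path, then apply the output projection. Activations
    /// are temporarily cast to the quantized activation dtype when one is required.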
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&x, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&x, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&x, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut y = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // No paged KV cache metadata was provided, so construct dummy
                    // metadata and run attention without populating the cache.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check: this path expects a causal mask to be present.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask.clone().as_ref(),
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        y = if attention_mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&y, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

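    /// Load the attention weights from `vb`, splitting the Q/K/V projections
    /// column-wise and the output projection row-wise across the tensor-parallel
    /// communicator.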
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = ColumnParallelLayer::new(
            size_in,
            size_q,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            size_q,
            size_in,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            num_key_value_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}

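/// A single decoder block: RMSNorm, self-attention, residual add, then RMSNorm,
/// MLP, residual add.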
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Box<dyn MlpLayer>,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(
            &x,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            cfg,
            rope,
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_act,
            comm,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp: Box::new(mlp),
        })
    }
}

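/// The full Llama decoder stack: token embeddings, the repeated decoder blocks, a
/// final RMSNorm, and the (optionally tied) LM head.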
pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    kv_cache: crate::pipeline::EitherCache,
    device: Device,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Llama {
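    /// Construct the model from a top-level var builder, using the standard
    /// `model`/`lm_head` prefixes before delegating to `Self::new_inner`.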
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let vb_lm_head = vb.pp("lm_head");
        Self::new_inner(
            cfg,
            vb_m,
            vb_lm_head,
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )
    }

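    /// Build the model from pre-split var builders: load the embeddings, per-device
    /// rotary embeddings, every decoder block (device-mapped), the final norm, and
    /// the LM head (tied to the embeddings when `tie_word_embeddings` is set).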
    pub fn new_inner(
        cfg: &Config,
        vb_m: ShardedVarBuilder,
        vb_lm_head: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb_m)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(wte.embeddings(), normal_loading_metadata.loading_isq)?,
                None,
            ))?
        };
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama3(
                    vb_m.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(i).unwrap();
            Block::load(
                vb_m.pp(format!("layers.{i}")),
                cfg,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                paged_attn,
                &comm,
            )
            .expect("Failed to load block.")
        })
        .collect();

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.wte.forward(input_ids)
    }

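    /// Forward pass from token ids: embeds `input_ids` and calls `forward_embeds`.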
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            self.wte.forward(input_ids)?,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }

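    /// Forward pass from precomputed input embeddings. Builds the causal mask, runs
    /// every block with its KV cache (and optional paged-attention metadata), applies
    /// the final norm and LM head, and returns logits selected by `context_lens`.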
    #[allow(clippy::too_many_arguments)]
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        input_embeds: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = input_embeds;
        let cache = &mut self.kv_cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            x.dtype(),
            self.blocks[0].attn.num_attention_heads,
        )?;
        // For chunked prompts (PagedAttention), only keep the causal mask on the
        // first prompt chunk; without metadata the mask is always kept.
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                &mut cache[block_idx],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), *metadata)),
                flash_params,
            )?;
        }
        let x = x.to_device(&self.device)?;
        let mut x = self.ln_f.forward(&x)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let xs = MatMul.qmethod_matmul(&x, &*self.lm_head)?;
        extract_logits(&xs, context_lens)
    }

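    /// Collect the non-quantized ("residual") tensors under the `model` prefix,
    /// i.e. the embeddings and the various norms, for serialization.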
    pub fn residual_tensors_m(&self, uvb_m: UnVarBuilder) -> Vec<(String, Tensor)> {
        uvb_m.pp("embed_tokens").add(&self.wte);
        uvb_m.pp("norm").add(&self.ln_f);

        for (layer_idx, layer) in self.blocks.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.rms_1);
            uvb_l.pp("post_attention_layernorm").add(&layer.rms_2);
        }

        uvb_m.to_safetensors()
    }
}

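// In-situ quantization support: expose the quantizable linear layers (LM head plus
// per-block attention and MLP projections) together with their device mapping.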
impl IsqModel for Llama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((&mut layer.attn.q_proj, Some(i)));
            tensors.push((&mut layer.attn.k_proj, Some(i)));
            tensors.push((&mut layer.attn.v_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();
        self.residual_tensors_m(uvb.pp("model"))
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // NOTE: the ordering here must match the tensor ordering in `get_layers`.
        let mut names = Vec::new();
        // lm_head
        names.push(None);
        for i in 0..self.blocks.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

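// Pipeline integration: the standard forward pass plus cache/device/config accessors.
// X-LoRA is not supported for this model, so `xlora_forward` is unimplemented.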
impl NormalModel for Llama {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &crate::pipeline::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

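// AnyMoE support: each block's MLP can be swapped for a `MoeMlp` whose experts are
// either fully fine-tuned copies or LoRA-adapted deltas of the original MLP.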
impl AnyMoeBaseModelMixin for Llama {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.blocks {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.blocks {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.blocks.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.blocks[layer].mlp.get_params()[1];
                let hidden_size = self.blocks[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.blocks[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.blocks[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.blocks[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let c_fc1_delta = if target_modules.contains(&"c_fc1".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc1"
                            ))
                        } else {
                            None
                        };
                        let c_fc2_delta = if target_modules.contains(&"c_fc2".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc2"
                            ))
                        } else {
                            None
                        };
                        let c_proj_delta = if target_modules.contains(&"c_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "c_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.blocks[layer].mlp.new_added_delta(vec![
                            c_fc1_delta,
                            c_fc2_delta,
                            c_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.blocks[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.blocks[layer].mlp.dtype_device();
            self.blocks[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}