#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, Activation, CausalMasker, Llama3RopeConfig, Llama3RotaryEmbedding, MatMul, Mlp,
        RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

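// These macros generate the default-value functions referenced by the
// `#[serde(default = "...")]` attributes on `Config` below.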
serde_default_fn!(bool, word_emb_default, false);
serde_default_fn!(bool, use_flash_attn_default, false);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Config {
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    #[serde(default = "use_flash_attn_default")]
    pub use_flash_attn: bool,
    pub rms_norm_eps: f64,
    pub rope_theta: f32,
    pub max_position_embeddings: usize,
    pub rope_scaling: Option<Llama3RopeConfig>,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}

struct CausalSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
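        // x: (b_sz, seq_len, hidden_size)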
        let (b_sz, seq_len, _) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&x, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&x, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&x, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

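        // When seq_len == 1 (single-token decode) transposing the singleton
        // sequence dim is a layout no-op, so reshape directly to the target
        // (b, heads, seq, head_dim) shape and skip the transpose.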
        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

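        // Dispatch on the attention backend: paged attention (with real or
        // dummy cache metadata) or eager SDPA over the per-sequence KV cache.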
        let mut y = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
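                    // No real cache metadata was provided (prompt-only
                    // processing, e.g. imatrix collection), so a causal mask
                    // must be present.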
                    assert!(attention_mask.is_some());
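                    // Run paged attention without a persistent KV cache.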
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask.clone().as_ref(),
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        y = if attention_mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&y, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
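        // size_q/size_kv are the projection output widths: head_dim times the
        // number of query heads and KV heads, respectively.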
        let q_proj = ColumnParallelLayer::new(
            size_in,
            size_q,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
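        // Under tensor parallelism the KV heads are sharded across the ranks
        // in `comm`; compute the shard so k/v projections load only their slice.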
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            size_q,
            size_in,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            num_key_value_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}

struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Box<dyn MlpLayer>,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
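        // Pre-norm residual block:
        //   x = x + Attn(RMSNorm(x)); x = x + MLP(RMSNorm(x))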
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(
            &x,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            cfg,
            rope,
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_act,
            comm,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp: Box::new(mlp),
        })
    }
}

pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    kv_cache: crate::pipeline::EitherCache,
    device: Device,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Llama {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let vb_lm_head = vb.pp("lm_head");
        Self::new_inner(
            cfg,
            vb_m,
            vb_lm_head,
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )
    }

    pub fn new_inner(
        cfg: &Config,
        vb_m: ShardedVarBuilder,
        vb_lm_head: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb_m)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
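        // With tied word embeddings the LM head reuses the embedding matrix;
        // otherwise load a separate (replicated) output projection.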
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(wte.embeddings(), normal_loading_metadata.loading_isq)?,
                None,
            ))?
        };
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut ropes = HashMap::new();
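        // Build one rotary embedding per device location so layers mapped to
        // different devices never read RoPE tables across devices.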
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama3(
                    vb_m.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(i).unwrap();
            Block::load(
                vb_m.pp(format!("layers.{i}")),
                cfg,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                paged_attn,
                &comm,
            )
            .expect("Failed to load block.")
        })
        .collect();

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.wte.forward(input_ids)
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            self.wte.forward(input_ids)?,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }

    #[allow(clippy::too_many_arguments)]
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        input_embeds: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = input_embeds;
        let cache = &mut self.kv_cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            x.dtype(),
            self.blocks[0].attn.num_attention_heads,
        )?;
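        // PagedAttention prompt chunking: only the first chunk of a prompt
        // gets the causal mask.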
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                &mut cache[block_idx],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), *metadata)),
                flash_params,
            )?;
        }
        let x = x.to_device(&self.device)?;
        let mut x = self.ln_f.forward(&x)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let xs = MatMul.qmethod_matmul(&x, &*self.lm_head)?;
        extract_logits(&xs, context_lens)
    }

    pub fn residual_tensors_m(&self, uvb_m: UnVarBuilder) -> Vec<(String, Tensor)> {
        uvb_m.pp("embed_tokens").add(&self.wte);
        uvb_m.pp("norm").add(&self.ln_f);

        for (layer_idx, layer) in self.blocks.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.rms_1);
            uvb_l.pp("post_attention_layernorm").add(&layer.rms_2);
        }

        uvb_m.to_safetensors()
    }
}

impl IsqModel for Llama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
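        // lm_head is pushed first (no layer index); the ordering here must
        // stay in sync with `imatrix_names` below.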
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((&mut layer.attn.q_proj, Some(i)));
            tensors.push((&mut layer.attn.k_proj, Some(i)));
            tensors.push((&mut layer.attn.v_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();
        self.residual_tensors_m(uvb.pp("model"))
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        let mut names = Vec::new();
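        // The leading `None` corresponds to lm_head in the `get_layers` ordering.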
        names.push(None);
        for i in 0..self.blocks.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

impl NormalModel for Llama {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &crate::pipeline::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Llama {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.blocks {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.blocks {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.blocks.len()).collect::<Vec<_>>();
        }
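        // One (initially empty) expert list per target layer; each additional
        // var-builder below contributes one expert per layer.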
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.blocks[layer].mlp.get_params()[1];
                let hidden_size = self.blocks[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.blocks[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.blocks[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.blocks[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let c_fc1_delta = if target_modules.contains(&"c_fc1".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc1"
                            ))
                        } else {
                            None
                        };
                        let c_fc2_delta = if target_modules.contains(&"c_fc2".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc2"
                            ))
                        } else {
                            None
                        };
                        let c_proj_delta = if target_modules.contains(&"c_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "c_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.blocks[layer].mlp.new_added_delta(vec![
                            c_fc1_delta,
                            c_fc2_delta,
                            c_proj_delta,
                        ])?);
                    }
                }
            }
        }
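        // Replace each target layer's MLP with a gated MoE over the original
        // MLP plus the newly built experts.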
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.blocks[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.blocks[layer].mlp.dtype_device();
            self.blocks[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}