1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{Device, Result, Tensor};
4use candle_nn::{Embedding, Module};
5use mistralrs_quant::{
6 ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
7 ShardedVarBuilder,
8};
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, sync::Arc};
11
12use crate::{
13 amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
14 attention::SdpaParams,
15 device_map::DeviceMapper,
16 get_delta_from_lora_ab,
17 layers::{
18 embedding, Activation, CausalMasker, Llama3RopeConfig, Llama3RotaryEmbedding, MatMul, Mlp,
19 RmsNorm, Sdpa,
20 },
21 layers_masker::PastKvLenCache,
22 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
23 pipeline::{
24 extract_logits,
25 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
26 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
27 },
28 serde_default_fn,
29 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
30};
31
// Serde default for `Config::tie_word_embeddings` when the key is absent (`false`).
serde_default_fn!(bool, word_emb_default, false);

/// Hyperparameters for the Llama model, deserialized with serde.
///
/// NOTE(review): presumably parsed from a Hugging Face style `config.json`;
/// field names must match the keys of that file — confirm against the loader.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Config {
    /// Activation passed to each layer's `Mlp`.
    pub hidden_act: Activation,
    /// Model width: embedding size and input/output size of every sublayer.
    pub hidden_size: usize,
    /// Inner width of each layer's MLP.
    pub intermediate_size: usize,
    /// Number of embedding rows and `lm_head` outputs.
    pub vocab_size: usize,
    /// Number of decoder blocks (and per-layer KV caches).
    pub num_hidden_layers: usize,
    /// Query heads; the per-head width is `hidden_size / num_attention_heads`.
    pub num_attention_heads: usize,
    /// Key/value heads (fewer than query heads for grouped-query attention).
    pub num_key_value_heads: usize,
    /// Epsilon used by every `RmsNorm`.
    pub rms_norm_eps: f64,
    /// RoPE base. NOTE(review): consumed inside `Llama3RotaryEmbedding::new_llama3`,
    /// which receives the whole config — not visible here.
    pub rope_theta: f32,
    /// Maximum sequence length; sizes the KV cache and is reported as `max_seq_len`.
    pub max_position_embeddings: usize,
    /// Optional Llama 3 RoPE scaling parameters (presumably read by
    /// `Llama3RotaryEmbedding::new_llama3`).
    pub rope_scaling: Option<Llama3RopeConfig>,
    /// Pre-quantization settings forwarded to every linear-layer loader.
    pub quantization_config: Option<QuantizedConfig>,
    /// When true, `lm_head` reuses the token-embedding matrix instead of
    /// loading its own weights.
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}
51
/// Self-attention sublayer of one decoder block.
///
/// Under tensor parallelism this holds only this rank's shard: the projections
/// are column/row-parallel layers and the head counts are divided by the
/// communicator's world size.
struct CausalSelfAttention {
    /// Query projection (column-parallel).
    q_proj: Arc<dyn QuantMethod>,
    /// Key projection (column-parallel, KV-sharded).
    k_proj: Arc<dyn QuantMethod>,
    /// Value projection (column-parallel, KV-sharded).
    v_proj: Arc<dyn QuantMethod>,
    /// Output projection (row-parallel).
    o_proj: Arc<dyn QuantMethod>,
    /// Query heads held by this rank (`num_attention_heads / world_size`).
    num_attention_heads: usize,
    /// Key/value heads held by this rank (never less than 1).
    num_key_value_heads: usize,
    /// Width of a single head, `hidden_size / num_attention_heads`.
    head_dim: usize,
    /// Rotary embedding shared by all layers on the same device.
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    /// Copied from `max_position_embeddings`.
    max_seq_len: usize,
    /// Paged-attention backend; `None` selects the eager SDPA + KV-cache path.
    paged_attn: Option<PagedAttention>,
    /// Parameters for the attention kernels (KV group count, softmax scale, ...).
    sdpa_params: SdpaParams,
}
65
66impl CausalSelfAttention {
67 #[allow(clippy::too_many_arguments)]
68 fn forward(
69 &self,
70 x: &Tensor,
71 attention_mask: &Option<Tensor>,
72 seqlen_offsets: &[usize],
73 kv_cache: &mut KvCache,
74 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
75 flash_params: &FlashParams,
76 ) -> Result<Tensor> {
77 let (b_sz, seq_len, _) = x.dims3()?;
78
79 let original_dtype = x.dtype();
80 let mut x = x.clone();
81 if let Some(t) = self.q_proj.quantized_act_type() {
82 x = x.to_dtype(t)?;
83 }
84 let mut q = MatMul.qmethod_matmul(&x, &*self.q_proj)?;
85 let mut k = MatMul.qmethod_matmul(&x, &*self.k_proj)?;
86 let mut v = MatMul.qmethod_matmul(&x, &*self.v_proj)?;
87 if self.q_proj.quantized_act_type().is_some() {
88 q = q.to_dtype(original_dtype)?;
89 k = k.to_dtype(original_dtype)?;
90 v = v.to_dtype(original_dtype)?;
91 }
92
93 let (q, k, v) = if seq_len != 1 {
94 let q = q
95 .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
96 .transpose(1, 2)?;
97 let k = k
98 .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
99 .transpose(1, 2)?;
100 let v = v
101 .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
102 .transpose(1, 2)?;
103 (q, k, v)
104 } else {
105 let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
106 let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
107 let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
108 (q, k, v)
109 };
110
111 let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
112
113 let mut y = match &self.paged_attn {
114 Some(paged_attn) => match metadata {
115 Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
116 &q,
117 &k,
118 &v,
119 attention_mask.clone().as_ref(),
120 Some(key_cache),
121 Some(value_cache),
122 input_metadata,
123 &self.sdpa_params,
124 Some(flash_params),
125 )?,
126 None => {
127 let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
130 assert!(attention_mask.is_some());
132 paged_attn.forward(
133 &q,
134 &k,
135 &v,
136 attention_mask.clone().as_ref(),
137 None,
138 None,
139 &input_metadata,
140 &self.sdpa_params,
141 Some(flash_params),
142 )?
143 }
144 },
145 None => {
146 let (k, v) = kv_cache.append(&k, &v)?;
147
148 Sdpa.run_attention(
149 &q,
150 &k,
151 &v,
152 attention_mask.clone().as_ref(),
153 Some(flash_params),
154 &self.sdpa_params,
155 )?
156 }
157 };
158
159 if let Some(t) = self.q_proj.quantized_act_type() {
160 y = y.to_dtype(t)?;
161 }
162 y = if attention_mask.is_some() {
163 y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
164 } else {
165 y.reshape((b_sz, seq_len, ()))?
166 };
167 let mut res = MatMul.qmethod_matmul(&y, &*self.o_proj)?;
168 if self.q_proj.quantized_act_type().is_some() {
169 res = res.to_dtype(original_dtype)?;
170 }
171 Ok(res)
172 }
173
174 fn load(
175 vb: ShardedVarBuilder,
176 cfg: &Config,
177 rope: Arc<Llama3RotaryEmbedding>,
178 paged_attn: Option<PagedAttention>,
179 comm: &Arc<mistralrs_quant::Comm>,
180 ) -> Result<Self> {
181 let size_in = cfg.hidden_size;
182 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
183 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
184 let q_proj = ColumnParallelLayer::new(
185 size_in,
186 size_q,
187 &cfg.quantization_config,
188 false,
189 comm,
190 vb.pp("q_proj"),
191 )?;
192 let kv_shard = mistralrs_quant::compute_kv_shard(
193 cfg.num_key_value_heads,
194 cfg.hidden_size / cfg.num_attention_heads,
195 comm,
196 );
197 let k_proj = ColumnParallelLayer::new_with_shard(
198 size_in,
199 size_kv,
200 &cfg.quantization_config,
201 false,
202 comm,
203 kv_shard,
204 vb.pp("k_proj"),
205 )?;
206 let v_proj = ColumnParallelLayer::new_with_shard(
207 size_in,
208 size_kv,
209 &cfg.quantization_config,
210 false,
211 comm,
212 kv_shard,
213 vb.pp("v_proj"),
214 )?;
215 let o_proj = RowParallelLayer::new(
216 size_q,
217 size_in,
218 &cfg.quantization_config,
219 false,
220 comm,
221 vb.pp("o_proj"),
222 )?;
223 Ok(Self {
224 q_proj,
225 k_proj,
226 v_proj,
227 o_proj,
228 num_attention_heads: cfg.num_attention_heads / comm.world_size(),
229 num_key_value_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
230 head_dim: cfg.hidden_size / cfg.num_attention_heads,
231 rotary_emb: rope,
232 max_seq_len: cfg.max_position_embeddings,
233 paged_attn,
234 sdpa_params: SdpaParams {
235 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
236 cfg.num_key_value_heads,
237 cfg.num_attention_heads,
238 comm,
239 ),
240 softcap: None,
241 softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
242 sliding_window: None,
243 },
244 })
245 }
246}
247
/// One pre-norm decoder layer: attention and MLP sublayers, each preceded by an
/// RMS norm and wrapped in a residual connection.
struct Block {
    /// Norm applied before attention (`input_layernorm` weights).
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    /// Norm applied before the MLP (`post_attention_layernorm` weights).
    rms_2: RmsNorm,
    /// Boxed trait object so AnyMoE can replace it with an `MoeMlp`.
    mlp: Box<dyn MlpLayer>,
}
254
255impl Block {
256 #[allow(clippy::too_many_arguments)]
257 fn forward(
258 &self,
259 x: &Tensor,
260 attention_mask: &Option<Tensor>,
261 seqlen_offsets: &[usize],
262 kv_cache: &mut KvCache,
263 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
264 flash_params: &FlashParams,
265 ) -> Result<Tensor> {
266 let residual = x;
267 let x = self.rms_1.forward(x)?;
268 let x = (self.attn.forward(
269 &x,
270 attention_mask,
271 seqlen_offsets,
272 kv_cache,
273 metadata,
274 flash_params,
275 )? + residual)?;
276 let residual = &x;
277 let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
278 Ok(x)
279 }
280
281 #[allow(clippy::too_many_arguments)]
282 fn load(
283 vb: ShardedVarBuilder,
284 cfg: &Config,
285 mapper: &dyn DeviceMapper,
286 layer_idx: usize,
287 loading_isq: bool,
288 rope: Arc<Llama3RotaryEmbedding>,
289 paged_attn: Option<PagedAttention>,
290 comm: &Arc<mistralrs_quant::Comm>,
291 ) -> Result<Self> {
292 let attn = CausalSelfAttention::load(
293 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
294 cfg,
295 rope,
296 paged_attn,
297 comm,
298 )?;
299 let mlp = Mlp::new(
300 mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
301 cfg.hidden_size,
302 cfg.intermediate_size,
303 &cfg.quantization_config,
304 cfg.hidden_act,
305 comm,
306 )?;
307 let rms_1 = RmsNorm::new(
308 cfg.hidden_size,
309 cfg.rms_norm_eps,
310 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
311 )?;
312 let rms_2 = RmsNorm::new(
313 cfg.hidden_size,
314 cfg.rms_norm_eps,
315 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
316 )?;
317 Ok(Self {
318 rms_1,
319 attn,
320 rms_2,
321 mlp: Box::new(mlp),
322 })
323 }
324}
325
/// Llama-family decoder-only language model.
pub struct Llama {
    /// Token embedding (`embed_tokens` weights).
    wte: Embedding,
    /// Decoder layers, applied in order.
    blocks: Vec<Block>,
    /// Final RMS norm applied before `lm_head` (`norm` weights).
    ln_f: RmsNorm,
    /// Projection to vocabulary logits; reuses the embedding matrix when
    /// `tie_word_embeddings` is set.
    lm_head: Arc<dyn QuantMethod>,
    /// Per-layer KV caches; appended to by the eager attention path.
    kv_cache: crate::pipeline::EitherCache,
    /// Device on which the final norm and `lm_head` run.
    device: Device,
    /// Maps each layer to its device and tensor-parallel communicator.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Model metadata reported through `NormalModel::config`.
    cfg: ModelConfigMetadata,
}
336
337impl Llama {
338 pub fn new(
339 cfg: &Config,
340 vb: ShardedVarBuilder,
341 is_gptx: bool,
342 normal_loading_metadata: NormalLoadingMetadata,
343 attention_mechanism: AttentionImplementation,
344 ) -> Result<Self> {
345 let vb_m = vb.pp("model");
346 let vb_lm_head = vb.pp("lm_head");
347 Self::new_inner(
348 cfg,
349 vb_m,
350 vb_lm_head,
351 is_gptx,
352 normal_loading_metadata,
353 attention_mechanism,
354 )
355 }
356
357 pub fn new_inner(
358 cfg: &Config,
359 vb_m: ShardedVarBuilder,
360 vb_lm_head: ShardedVarBuilder,
361 is_gptx: bool,
362 normal_loading_metadata: NormalLoadingMetadata,
363 attention_mechanism: AttentionImplementation,
364 ) -> Result<Self> {
365 if let Some(ref quant_cfg) = &cfg.quantization_config {
366 tracing::info!(
367 "Using {} quantization: {}.",
368 quant_cfg.name(),
369 quant_cfg.get_bits_name(&vb_m)
370 );
371 }
372 let mapper = normal_loading_metadata.mapper;
373
374 let wte = embedding(
375 cfg.vocab_size,
376 cfg.hidden_size,
377 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
378 &cfg.quantization_config,
379 )?;
380 let lm_head = if !cfg.tie_word_embeddings {
381 ReplicatedLayer::new(
382 cfg.hidden_size,
383 cfg.vocab_size,
384 &cfg.quantization_config,
385 false,
386 mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
387 )?
388 } else {
389 ReplicatedLayer::from_linear(candle_nn::Linear::new(
390 mapper.cast_nm_device(wte.embeddings(), normal_loading_metadata.loading_isq)?,
391 None,
392 ))?
393 };
394 let ln_f = RmsNorm::new(
395 cfg.hidden_size,
396 cfg.rms_norm_eps,
397 mapper.set_nm_device(vb_m.pp("norm"), false),
398 )?;
399 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
400 let mut ropes = HashMap::new();
401 for i in 0..cfg.num_hidden_layers {
402 let device = mapper
403 .device_for(i, false)
404 .unwrap_or(&normal_loading_metadata.real_device);
405 ropes.insert(
406 device.location(),
407 Arc::new(Llama3RotaryEmbedding::new_llama3(
408 vb_m.dtype(),
409 cfg,
410 device,
411 is_gptx,
412 )?),
413 );
414 }
415 let blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
416 0..cfg.num_hidden_layers,
417 "Loading repeating layers",
418 &normal_loading_metadata.multi_progress,
419 )
420 .par_iter_if_isq(|i| {
421 let device = mapper
422 .device_for(i, false)
423 .unwrap_or(&normal_loading_metadata.real_device);
424 let rotary_emb = ropes
425 .get(&device.location())
426 .expect("No RoPE for device location!")
427 .clone();
428 let paged_attn = match &attention_mechanism {
429 AttentionImplementation::Eager => None,
430 AttentionImplementation::PagedAttention => {
431 Some(PagedAttention::new(head_dim, device, None)?)
432 }
433 };
434 let comm = mapper.get_comm_for(i)?;
435 Block::load(
436 vb_m.pp(format!("layers.{i}")),
437 cfg,
438 &*mapper,
439 i,
440 normal_loading_metadata.loading_isq,
441 rotary_emb,
442 paged_attn,
443 &comm,
444 )
445 })?;
446
447 Ok(Self {
448 wte,
449 blocks,
450 ln_f,
451 lm_head,
452 kv_cache: EitherCache::Normal(NormalCache::new(
453 cfg.num_hidden_layers,
454 cfg.max_position_embeddings,
455 )),
456 device: normal_loading_metadata.real_device,
457 cfg: ModelConfigMetadata {
458 max_seq_len: cfg.max_position_embeddings,
459 num_layers: cfg.num_hidden_layers,
460 hidden_size: cfg.hidden_size,
461 num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
462 .max(1),
463 num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
464 sliding_window: None,
465 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
466 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
467 },
468 mapper,
469 })
470 }
471
472 pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
473 self.wte.forward(input_ids)
474 }
475
476 pub fn forward(
477 &self,
478 input_ids: &Tensor,
479 seqlen_offsets: &[usize],
480 context_lens: Vec<(usize, usize)>,
481 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
482 flash_params: &FlashParams,
483 ) -> Result<Tensor> {
484 self.forward_embeds(
485 input_ids,
486 self.wte.forward(input_ids)?,
487 seqlen_offsets,
488 context_lens,
489 metadata,
490 flash_params,
491 )
492 }
493
494 #[allow(clippy::too_many_arguments)]
495 pub fn forward_embeds(
496 &self,
497 input_ids: &Tensor,
498 input_embeds: Tensor,
499 seqlen_offsets: &[usize],
500 context_lens: Vec<(usize, usize)>,
501 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
502 flash_params: &FlashParams,
503 ) -> Result<Tensor> {
504 let mut x = input_embeds;
505 let cache = &mut self.kv_cache.normal().0;
506 let mask = CausalMasker.make_causal_mask_matrix(
507 input_ids,
508 metadata
509 .as_ref()
510 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
511 .unwrap_or(cache as &dyn PastKvLenCache),
512 x.dtype(),
513 self.blocks[0].attn.num_attention_heads,
514 )?;
515 let mask = mask.filter(|_| {
517 metadata
518 .as_ref()
519 .map(|(_, meta)| meta.is_first_prompt_chunk)
520 .unwrap_or(true)
521 });
522 for (block_idx, block) in self.blocks.iter().enumerate() {
523 x = self.mapper.map(x, block_idx)?;
524 x = block.forward(
525 &x,
526 &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
527 seqlen_offsets,
528 &mut cache[block_idx],
529 metadata
530 .as_ref()
531 .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), *metadata)),
532 flash_params,
533 )?;
534 }
535 let x = x.to_device(&self.device)?;
536 let mut x = self.ln_f.forward(&x)?;
537 if let Some(t) = self.lm_head.quantized_act_type() {
538 x = x.to_dtype(t)?;
539 }
540 let xs = MatMul.qmethod_matmul(&x, &*self.lm_head)?;
541 extract_logits(&xs, context_lens)
542 }
543
544 pub fn residual_tensors_m(&self, uvb_m: UnVarBuilder) -> Vec<(String, Tensor)> {
545 uvb_m.pp("embed_tokens").add(&self.wte);
546 uvb_m.pp("norm").add(&self.ln_f);
547
548 for (layer_idx, layer) in self.blocks.iter().enumerate() {
549 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
550 uvb_l.pp("input_layernorm").add(&layer.rms_1);
551 uvb_l.pp("post_attention_layernorm").add(&layer.rms_2);
552 }
553
554 uvb_m.to_safetensors()
555 }
556}
557
558impl IsqModel for Llama {
559 fn get_layers(
560 &mut self,
561 ) -> (
562 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
563 &dyn DeviceMapper,
564 ) {
565 let mut tensors = Vec::new();
566 tensors.push((&mut self.lm_head, None));
567 for (i, layer) in self.blocks.iter_mut().enumerate() {
568 tensors.push((&mut layer.attn.q_proj, Some(i)));
569 tensors.push((&mut layer.attn.k_proj, Some(i)));
570 tensors.push((&mut layer.attn.v_proj, Some(i)));
571 tensors.push((&mut layer.attn.o_proj, Some(i)));
572 tensors.extend(
573 layer
574 .mlp
575 .get_isq_layers()
576 .into_iter()
577 .map(|m| (m, Some(i)))
578 .collect::<Vec<_>>(),
579 );
580 }
581 (tensors, &*self.mapper)
582 }
583
584 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
585 let uvb = UnVarBuilder::new();
586 self.residual_tensors_m(uvb.pp("model"))
587 }
588
589 fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
590 let mut names = Vec::new();
592 names.push(None);
594 for i in 0..self.blocks.len() {
595 names.push(Some(format!("blk.{i}.attn_q.weight")));
596 names.push(Some(format!("blk.{i}.attn_k.weight")));
597 names.push(Some(format!("blk.{i}.attn_v.weight")));
598 names.push(Some(format!("blk.{i}.attn_output.weight")));
599 names.push(Some(format!("blk.{i}.ffn_gate.weight")));
600 names.push(Some(format!("blk.{i}.ffn_up.weight")));
601 names.push(Some(format!("blk.{i}.ffn_down.weight")));
602 }
603 Ok(names)
604 }
605}
606
607impl NormalModel for Llama {
608 fn forward(
609 &self,
610 input_ids: &Tensor,
611 seqlen_offsets: &[usize],
612 context_lens: Vec<(usize, usize)>,
613 _position_ids: Vec<usize>,
614 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
615 flash_params: &FlashParams,
616 ) -> Result<Tensor> {
617 self.forward(
618 input_ids,
619 seqlen_offsets,
620 context_lens,
621 metadata,
622 flash_params,
623 )
624 }
625 fn xlora_forward(
626 &self,
627 _input_ids: &Tensor,
628 _input_ids_full: &Tensor,
629 _seqlen_offsets: &[usize],
630 _seqlen_offsets_full: &[usize],
631 _no_kv_cache: bool,
632 _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
633 _context_lens: Vec<(usize, usize)>,
634 _position_ids: Vec<usize>,
635 _flash_params: &FlashParams,
636 _flash_params_full: &FlashParams,
637 ) -> Result<Tensor> {
638 unimplemented!()
639 }
640 fn cache(&self) -> &crate::pipeline::EitherCache {
641 &self.kv_cache
642 }
643 fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
644 &mut self.kv_cache
645 }
646 fn device(&self) -> &Device {
647 &self.device
648 }
649 fn is_xlora(&self) -> bool {
650 false
651 }
652 fn max_seq_len(&self) -> usize {
653 self.blocks[0].attn.max_seq_len
654 }
655 fn config(&self) -> &ModelConfigMetadata {
656 &self.cfg
657 }
658}
659
660impl AnyMoeBaseModelMixin for Llama {
661 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
662 let mut mlps = Vec::new();
663 for layer in &self.blocks {
664 mlps.push(&*layer.mlp);
665 }
666 mlps
667 }
668 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
669 let mut mlps = Vec::new();
670 for layer in &mut self.blocks {
671 mlps.push(&mut layer.mlp);
672 }
673 mlps
674 }
675 fn create_anymoe_layers(
676 &mut self,
677 additional_vbs: Vec<ShardedVarBuilder>,
678 config: AnyMoeConfig,
679 (prefix, mlp): (String, String),
680 mut layers: Vec<usize>,
681 expert_type: AnyMoeExpertType,
682 gate_vb: Option<ShardedVarBuilder>,
683 ) -> Result<()> {
684 let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
685 if layers.is_empty() {
686 layers = (0..self.blocks.len()).collect::<Vec<_>>();
687 }
688 for _ in 0..layers.len() {
689 experts.push(Vec::new());
690 }
691 for vb in additional_vbs {
692 let vb = vb.pp(&prefix);
693 for (layer, row) in experts.iter_mut().enumerate() {
694 if !layers.contains(&layer) {
695 continue;
696 }
697
698 let intermediate_size = self.blocks[layer].mlp.get_params()[1];
699 let hidden_size = self.blocks[layer].mlp.get_params()[0];
700 match expert_type {
701 AnyMoeExpertType::FineTuned => {
702 let (dtype, device) = self.blocks[layer].mlp.dtype_device();
703 row.push(Box::new(Mlp::replicate(
704 self.blocks[layer].mlp.get_params(),
705 vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
706 self.blocks[layer].mlp.hidden_act(),
707 &self.mapper.get_comm_for(layer)?,
708 )?));
709 }
710 AnyMoeExpertType::LoraAdapter {
711 rank,
712 alpha,
713 ref target_modules,
714 } => {
715 let vb_mlp = vb.pp(layer).pp(&mlp);
716
717 let c_fc1_delta = if target_modules.contains(&"c_fc1".to_string()) {
718 Some(get_delta_from_lora_ab!(
719 vb_mlp,
720 rank,
721 alpha,
722 (hidden_size, intermediate_size),
723 "c_fc1"
724 ))
725 } else {
726 None
727 };
728 let c_fc2_delta = if target_modules.contains(&"c_fc2".to_string()) {
729 Some(get_delta_from_lora_ab!(
730 vb_mlp,
731 rank,
732 alpha,
733 (hidden_size, intermediate_size),
734 "c_fc2"
735 ))
736 } else {
737 None
738 };
739 let c_proj_delta = if target_modules.contains(&"c_proj".to_string()) {
740 Some(get_delta_from_lora_ab!(
741 vb_mlp,
742 rank,
743 alpha,
744 (intermediate_size, hidden_size),
745 "c_proj"
746 ))
747 } else {
748 None
749 };
750
751 row.push(self.blocks[layer].mlp.new_added_delta(vec![
752 c_fc1_delta,
753 c_fc2_delta,
754 c_proj_delta,
755 ])?);
756 }
757 }
758 }
759 }
760 for (layer, expert) in layers.into_iter().zip(experts) {
761 let mut experts_all = vec![self.blocks[layer].mlp.clone()];
762 experts_all.extend(expert);
763 let (dtype, device) = self.blocks[layer].mlp.dtype_device();
764 self.blocks[layer].mlp = Box::new(MoeMlp::new(
765 experts_all,
766 config.clone(),
767 dtype,
768 &device,
769 layer,
770 gate_vb.as_ref(),
771 )?);
772 }
773 Ok(())
774 }
775 fn amoe_supported(&self) -> bool {
776 true
777 }
778}