#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);
serde_default_fn!(bool, use_flash_attn, false);

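/// Plain transformer text-model configuration, deserialized from the checkpoint's
/// `config.json` (HuggingFace-style field names).
///
/// A minimal loading sketch; the values are illustrative rather than taken from a real
/// checkpoint, and it assumes `serde_json` is available as a dependency:
///
/// ```ignore
/// let cfg: Config = serde_json::from_str(
///     r#"{
///         "vocab_size": 32000,
///         "hidden_size": 2048,
///         "intermediate_size": 5632,
///         "num_hidden_layers": 24,
///         "num_attention_heads": 16,
///         "num_key_value_heads": 16,
///         "max_position_embeddings": 32768,
///         "sliding_window": null,
///         "rope_theta": 10000.0,
///         "rms_norm_eps": 1e-6,
///         "hidden_act": "silu"
///     }"#,
/// )
/// .unwrap();
/// // Fields with `serde_default_fn!` defaults may be omitted from the JSON.
/// assert!(!cfg.tie_word_embeddings);
/// ```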
#[derive(Debug, Clone, serde::Deserialize, Default, serde::Serialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub max_position_embeddings: usize,
    pub sliding_window: Option<usize>,
    pub rope_theta: f64,
    pub rms_norm_eps: f64,
    pub hidden_act: Activation,
    #[serde(default = "use_flash_attn")]
    pub use_flash_attn: bool,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}

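/// Grouped-query self-attention block. The Q/K/V projections are column-parallel and the
/// output projection is row-parallel, so a layer can be sharded across the tensor-parallel
/// `Comm`; the K/V projections use a shard computed so every rank keeps at least one KV head.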
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

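    /// Projects the hidden states to Q/K/V, applies rotary embeddings, and runs either
    /// paged attention (when KV-cache metadata is supplied) or the regular SDPA path with
    /// the in-place `KvCache`. Activations are temporarily cast when the quant method
    /// requires a specific activation dtype.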
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

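        // Prefill (q_len > 1): reshape to (b, seq, heads, head_dim) and transpose to
        // (b, heads, seq, head_dim). Single-token decode reshapes directly since the
        // sequence dimension is 1.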
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

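/// A single pre-norm decoder block: RMSNorm -> self-attention -> residual add, then
/// RMSNorm -> MLP -> residual add.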
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_act,
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

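/// The full causal-LM model: token embeddings, a stack of decoder layers (device-mapped per
/// layer), a final RMSNorm, and the LM head.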
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
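        // With `tie_word_embeddings`, reuse the token-embedding matrix as the output head
        // instead of loading a separate `lm_head` weight.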
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let xs = self.embed_tokens.forward(input_ids)?;
        self.forward_embed(
            input_ids,
            xs,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }

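    /// Runs the decoder stack on pre-computed input embeddings: builds the (optionally
    /// sliding-window) causal mask, forwards every layer with its KV cache, applies the
    /// final norm, and projects to logits for the requested context positions.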
    #[allow(clippy::too_many_arguments)]
    pub fn forward_embed(
        &self,
        input_ids: &Tensor,
        mut xs: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
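        // Run each layer on the device the mapper assigned to it, moving the hidden states
        // (and a copy of the mask) across devices as needed.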
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }

    pub fn embed_dtype(&self) -> DType {
        self.embed_tokens.embeddings().dtype()
    }
}

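// ISQ (in-situ quantization) support: expose every quantizable linear layer plus the
// residual (non-quantized) tensors that must be preserved when re-serializing.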
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        let mut names = Vec::new();
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

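// AnyMoE support: each decoder MLP can be replaced by a `MoeMlp` that gates between the
// original MLP and additional per-expert MLPs (fine-tuned weights or LoRA deltas).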
impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.layers[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.layers[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

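                        // For each targeted projection, materialize the LoRA delta from the
                        // expert's A/B matrices; the deltas are then applied on top of a
                        // copy of the base MLP below.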
                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "gate_proj"
                            ))
                        } else {
                            None
                        };
                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "up_proj"
                            ))
                        } else {
                            None
                        };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.layers[layer].mlp.new_added_delta(vec![
                            gate_proj_delta,
                            up_proj_delta,
                            down_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}