#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, Module, Result, Tensor};
use candle_nn::Linear;
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantMethodConfig, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder, UnquantLinear,
};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

fn default_max_position_embeddings() -> usize {
    4096
}

serde_default_fn!(bool, word_emb_default, false);

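/// Gemma-style text model configuration, deserialized from the checkpoint's `config.json`.
/// Either `hidden_act` or `hidden_activation` may be present; see [`Config::hidden_act`].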
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Default)]
pub struct Config {
    pub attention_bias: bool,
    pub head_dim: usize,
    pub hidden_act: Option<Activation>,
    pub hidden_activation: Option<Activation>,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub vocab_size: usize,

    #[serde(default = "default_max_position_embeddings")]
    pub max_position_embeddings: usize,
    pub use_flash_attn: bool,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    #[allow(dead_code)]
    pub tie_word_embeddings: bool,
}

impl Config {
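    /// Resolve the activation function: exactly one of `hidden_act` and
    /// `hidden_activation` must be set, otherwise this is an error.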
    pub fn hidden_act(&self) -> Result<Activation> {
        match (self.hidden_act, self.hidden_activation) {
            (None, Some(act)) | (Some(act), None) => Ok(act),
            (Some(_), Some(_)) => {
                candle_core::bail!("both hidden_act and hidden_activation are set")
            }
            (None, None) => candle_core::bail!("neither hidden_act nor hidden_activation is set"),
        }
    }
}

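/// Self-attention with column-parallel Q/K/V projections, a row-parallel output
/// projection, rotary position embeddings, and either paged attention or plain SDPA.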
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

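    /// Project to Q/K/V (optionally in the quantized activation dtype), apply RoPE,
    /// run paged attention or SDPA against the KV cache, then project back via `o_proj`.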
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

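        // Prefill (q_len > 1) reshapes to (b, heads, seq, head_dim) via a transpose;
        // single-token decode can reshape directly since seq == 1.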
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
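                    // No cache metadata was provided: build dummy paged-attention
                    // metadata and run without touching the KV cache.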
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

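/// A single decoder layer: self-attention and a gated MLP, each preceded by a
/// Gemma-style RMSNorm.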
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_act()?,
            comm,
        )?;
        let input_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

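    /// Pre-norm residual block: normalize, attend, add the residual; then normalize,
    /// apply the MLP, and add the second residual.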
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

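/// The full causal language model: token embeddings, a stack of decoder layers, a final
/// RMSNorm, and an `lm_head` that ties (reuses) the embedding weights.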
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    hidden_size: usize,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
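    /// Build the model: load the embeddings, construct one rotary embedding per mapped
    /// device, then load each decoder layer onto the device assigned by the mapper.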
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
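        // The LM head ties (reuses) the token-embedding weights as an unquantized linear
        // layer with no bias.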
        let lm_head = mapper.cast_nm_device(
            embed_tokens.embeddings(),
            normal_loading_metadata.loading_isq,
        )?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                Linear::new(lm_head, None),
            ))?),
            device: normal_loading_metadata.real_device,
            hidden_size: cfg.hidden_size,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            max_seq_len: default_max_position_embeddings(),
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.head_dim,
                v_head_dim: cfg.head_dim,
            },
            mapper,
        })
    }

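    /// Forward pass: embed and scale the input ids, build the causal mask, run every
    /// decoder layer (moving activations between mapped devices), then apply the final
    /// norm and the LM head, extracting logits for the requested context lengths.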
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let xs = self.embed_tokens.forward(input_ids)?;
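        // The embeddings are scaled by sqrt(hidden_size) before entering the decoder stack.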
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
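        // With PagedAttention prompt chunking, only the first prompt chunk uses the
        // causal mask.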
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l
                .pp("input_layernorm")
                .add(&layer.input_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
        }

        uvb.to_safetensors()
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        let mut names = Vec::new();
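        // The first entry corresponds to `lm_head` (pushed first in `get_layers`), which
        // has no imatrix name.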
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
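    /// Replace the MLP of each selected layer with a `MoeMlp` that gates over the
    /// original MLP plus one expert per additional var builder, where experts are either
    /// fully fine-tuned MLPs or LoRA A/B deltas added onto the base weights.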
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.layers[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.layers[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "gate_proj"
                            ))
                        } else {
                            None
                        };
                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "up_proj"
                            ))
                        } else {
                            None
                        };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.layers[layer].mlp.new_added_delta(vec![
                            gate_proj_delta,
                            up_proj_delta,
                            down_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}