1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{Device, Module, Result, Tensor};
6use candle_nn::Linear;
7use mistralrs_quant::{
8 ColumnParallelLayer, QuantMethod, QuantMethodConfig, QuantizedConfig, RowParallelLayer,
9 ShardedVarBuilder, UnquantLinear,
10};
11
12use crate::{
13 amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
14 attention::SdpaParams,
15 device_map::DeviceMapper,
16 get_delta_from_lora_ab,
17 layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
18 layers_masker::PastKvLenCache,
19 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
20 pipeline::{
21 extract_logits,
22 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
23 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
24 },
25 serde_default_fn,
26 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
27};
28
29fn default_max_position_embeddings() -> usize {
30 4096
31}
32
33serde_default_fn!(bool, word_emb_default, false);
34
35#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Default)]
36pub struct Config {
37 pub attention_bias: bool,
38 pub head_dim: usize,
39 pub hidden_act: Option<Activation>,
41 pub hidden_activation: Option<Activation>,
42 pub hidden_size: usize,
43 pub intermediate_size: usize,
44 pub num_attention_heads: usize,
45 pub num_hidden_layers: usize,
46 pub num_key_value_heads: usize,
47 pub rms_norm_eps: f64,
48 pub rope_theta: f64,
49 pub vocab_size: usize,
50
51 #[serde(default = "default_max_position_embeddings")]
52 pub max_position_embeddings: usize,
53 pub use_flash_attn: bool,
54 pub quantization_config: Option<QuantizedConfig>,
55 #[serde(default = "word_emb_default")]
56 #[allow(dead_code)]
57 pub tie_word_embeddings: bool,
58}
59
60impl Config {
61 pub fn hidden_act(&self) -> Result<Activation> {
62 match (self.hidden_act, self.hidden_activation) {
63 (None, Some(act)) | (Some(act), None) => Ok(act),
64 (Some(_), Some(_)) => {
65 candle_core::bail!("both hidden_act and hidden_activation are set")
66 }
67 (None, None) => candle_core::bail!("none of hidden_act and hidden_activation are set"),
68 }
69 }
70}
71
72struct Attention {
73 q_proj: Arc<dyn QuantMethod>,
74 k_proj: Arc<dyn QuantMethod>,
75 v_proj: Arc<dyn QuantMethod>,
76 o_proj: Arc<dyn QuantMethod>,
77 num_heads: usize,
78 num_kv_heads: usize,
79 head_dim: usize,
80 rotary_emb: Arc<RotaryEmbedding>,
81 paged_attn: Option<PagedAttention>,
82 sdpa_params: SdpaParams,
83}
84
85impl Attention {
86 fn new(
87 rotary_emb: Arc<RotaryEmbedding>,
88 cfg: &Config,
89 vb: ShardedVarBuilder,
90 paged_attn: Option<PagedAttention>,
91 comm: &Arc<mistralrs_quant::Comm>,
92 ) -> Result<Self> {
93 let hidden_sz = cfg.hidden_size;
94 let num_heads = cfg.num_attention_heads;
95 let num_kv_heads = cfg.num_key_value_heads;
96 let head_dim = cfg.head_dim;
97 let bias = cfg.attention_bias;
98 let q_proj = ColumnParallelLayer::new(
99 hidden_sz,
100 num_heads * head_dim,
101 &cfg.quantization_config,
102 bias,
103 comm,
104 vb.pp("q_proj"),
105 )?;
106 let kv_shard = mistralrs_quant::compute_kv_shard(
107 cfg.num_key_value_heads,
108 cfg.hidden_size / cfg.num_attention_heads,
109 comm,
110 );
111 let k_proj = ColumnParallelLayer::new_with_shard(
112 hidden_sz,
113 num_kv_heads * head_dim,
114 &cfg.quantization_config,
115 bias,
116 comm,
117 kv_shard,
118 vb.pp("k_proj"),
119 )?;
120 let v_proj = ColumnParallelLayer::new_with_shard(
121 hidden_sz,
122 num_kv_heads * head_dim,
123 &cfg.quantization_config,
124 bias,
125 comm,
126 kv_shard,
127 vb.pp("v_proj"),
128 )?;
129 let o_proj = RowParallelLayer::new(
130 num_heads * head_dim,
131 hidden_sz,
132 &cfg.quantization_config,
133 bias,
134 comm,
135 vb.pp("o_proj"),
136 )?;
137 Ok(Self {
138 q_proj,
139 k_proj,
140 v_proj,
141 o_proj,
142 num_heads: num_heads / comm.world_size(),
143 num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
144 head_dim,
145 rotary_emb,
146 paged_attn,
147 sdpa_params: SdpaParams {
148 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
149 cfg.num_key_value_heads,
150 cfg.num_attention_heads,
151 comm,
152 ),
153 use_flash_attn: cfg.use_flash_attn,
154 softcap: None,
155 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
156 sliding_window: None,
157 },
158 })
159 }
160
161 #[allow(clippy::too_many_arguments)]
162 fn forward(
163 &self,
164 xs: &Tensor,
165 attention_mask: Option<&Tensor>,
166 seqlen_offsets: &[usize],
167 kv_cache: &mut KvCache,
168 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
169 flash_params: &FlashParams,
170 ) -> Result<Tensor> {
171 let (b_sz, q_len, _) = xs.dims3()?;
172
173 let original_dtype = xs.dtype();
174 let mut xs = xs.clone();
175 if let Some(t) = self.q_proj.quantized_act_type() {
176 xs = xs.to_dtype(t)?;
177 }
178 let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
179 let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
180 let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
181 if self.q_proj.quantized_act_type().is_some() {
182 q = q.to_dtype(original_dtype)?;
183 k = k.to_dtype(original_dtype)?;
184 v = v.to_dtype(original_dtype)?;
185 }
186
187 let (q, k, v) = if q_len != 1 {
188 let q = q
189 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
190 .transpose(1, 2)?;
191 let k = k
192 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
193 .transpose(1, 2)?;
194 let v = v
195 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
196 .transpose(1, 2)?;
197 (q, k, v)
198 } else {
199 let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
200 let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
201 let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
202 (q, k, v)
203 };
204
205 let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
206
207 let mut attn_output = match &self.paged_attn {
208 Some(paged_attn) => match metadata {
209 Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
210 &q,
211 &k,
212 &v,
213 attention_mask,
214 Some(key_cache),
215 Some(value_cache),
216 input_metadata,
217 &self.sdpa_params,
218 Some(flash_params),
219 )?,
220 None => {
221 let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
224 assert!(attention_mask.is_some());
226 paged_attn.forward(
227 &q,
228 &k,
229 &v,
230 attention_mask,
231 None,
232 None,
233 &input_metadata,
234 &self.sdpa_params,
235 Some(flash_params),
236 )?
237 }
238 },
239 None => {
240 let (k, v) = kv_cache.append(&k, &v)?;
241
242 Sdpa.run_attention(
243 &q,
244 &k,
245 &v,
246 attention_mask,
247 Some(flash_params),
248 &self.sdpa_params,
249 )?
250 }
251 };
252
253 if let Some(t) = self.q_proj.quantized_act_type() {
254 attn_output = attn_output.to_dtype(t)?;
255 }
256 attn_output = if attention_mask.is_some() {
257 attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
258 } else {
259 attn_output.reshape((b_sz, q_len, ()))?
260 };
261 let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
262 if self.q_proj.quantized_act_type().is_some() {
263 res = res.to_dtype(original_dtype)?;
264 }
265 Ok(res)
266 }
267}
268
269struct DecoderLayer {
270 self_attn: Attention,
271 mlp: Box<dyn MlpLayer>,
272 input_layernorm: RmsNorm,
273 post_attention_layernorm: RmsNorm,
274}
275
276impl DecoderLayer {
277 #[allow(clippy::too_many_arguments)]
278 fn new(
279 rotary_emb: Arc<RotaryEmbedding>,
280 cfg: &Config,
281 vb: ShardedVarBuilder,
282 mapper: &dyn DeviceMapper,
283 layer_idx: usize,
284 loading_isq: bool,
285 paged_attn: Option<PagedAttention>,
286 comm: &Arc<mistralrs_quant::Comm>,
287 ) -> Result<Self> {
288 let self_attn = Attention::new(
289 rotary_emb,
290 cfg,
291 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
292 paged_attn,
293 comm,
294 )?;
295 let mlp = Mlp::new(
296 mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
297 cfg.hidden_size,
298 cfg.intermediate_size,
299 &cfg.quantization_config,
300 cfg.hidden_act()?,
301 comm,
302 )?;
303 let input_layernorm = RmsNorm::new_gemma(
304 cfg.hidden_size,
305 cfg.rms_norm_eps,
306 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
307 )?;
308 let post_attention_layernorm = RmsNorm::new_gemma(
309 cfg.hidden_size,
310 cfg.rms_norm_eps,
311 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
312 )?;
313 Ok(Self {
314 self_attn,
315 mlp: Box::new(mlp),
316 input_layernorm,
317 post_attention_layernorm,
318 })
319 }
320
321 #[allow(clippy::too_many_arguments)]
322 fn forward(
323 &self,
324 xs: &Tensor,
325 attention_mask: Option<&Tensor>,
326 seqlen_offsets: &[usize],
327 kv_cache: &mut KvCache,
328 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
329 flash_params: &FlashParams,
330 ) -> Result<Tensor> {
331 let residual = xs;
332 let xs = self.input_layernorm.forward(xs)?;
333 let xs = self.self_attn.forward(
334 &xs,
335 attention_mask,
336 seqlen_offsets,
337 kv_cache,
338 metadata,
339 flash_params,
340 )?;
341 let xs = (xs + residual)?;
342 let residual = &xs;
343 let xs = self
344 .mlp
345 .forward(&xs.apply(&self.post_attention_layernorm)?)?;
346 residual + xs
347 }
348}
349
350pub struct Model {
351 embed_tokens: candle_nn::Embedding,
352 layers: Vec<DecoderLayer>,
353 norm: RmsNorm,
354 lm_head: Arc<dyn QuantMethod>,
355 hidden_size: usize,
356 device: Device,
357 cache: EitherCache,
358 max_seq_len: usize,
359 mapper: Box<dyn DeviceMapper + Send + Sync>,
360 cfg: ModelConfigMetadata,
361}
362
363impl Model {
364 pub fn new(
365 cfg: &Config,
366 vb: ShardedVarBuilder,
367 is_gptx: bool,
368 normal_loading_metadata: NormalLoadingMetadata,
369 attention_mechanism: AttentionImplementation,
370 ) -> Result<Self> {
371 if let Some(ref quant_cfg) = &cfg.quantization_config {
372 tracing::info!(
373 "Using {} quantization: {}.",
374 quant_cfg.quant_method.to_string(),
375 quant_cfg.get_bits_name(&vb)
376 );
377 }
378 let mapper = normal_loading_metadata.mapper;
379
380 let vb_m = vb.pp("model");
381 let embed_tokens = embedding(
382 cfg.vocab_size,
383 cfg.hidden_size,
384 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
385 )?;
386 let mut ropes = HashMap::new();
387 for layer_idx in 0..cfg.num_hidden_layers {
388 let device = mapper
389 .device_for(layer_idx, false)
390 .unwrap_or(&normal_loading_metadata.real_device);
391 ropes.insert(
392 device.location(),
393 Arc::new(RotaryEmbedding::new(
394 cfg.rope_theta as f32,
395 cfg.head_dim,
396 cfg.max_position_embeddings,
397 device,
398 is_gptx,
399 vb_m.dtype(),
400 )?),
401 );
402 }
403 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
404 let vb_l = vb_m.pp("layers");
405 for layer_idx in NiceProgressBar::<_, 'b'>(
406 0..cfg.num_hidden_layers,
407 "Loading repeating layers",
408 &normal_loading_metadata.multi_progress,
409 ) {
410 let device = mapper
411 .device_for(layer_idx, false)
412 .unwrap_or(&normal_loading_metadata.real_device);
413 let rotary_emb = ropes
414 .get(&device.location())
415 .expect("No RoPE for device location!")
416 .clone();
417 let paged_attn = match &attention_mechanism {
418 AttentionImplementation::Eager => None,
419 AttentionImplementation::PagedAttention => {
420 Some(PagedAttention::new(cfg.head_dim, device, None)?)
421 }
422 };
423 let comm = mapper.get_comm_for(layer_idx)?;
424 let layer = DecoderLayer::new(
425 rotary_emb.clone(),
426 cfg,
427 vb_l.pp(layer_idx),
428 &*mapper,
429 layer_idx,
430 normal_loading_metadata.loading_isq,
431 paged_attn,
432 &comm,
433 )?;
434 layers.push(layer)
435 }
436 let norm = RmsNorm::new_gemma(
437 cfg.hidden_size,
438 cfg.rms_norm_eps,
439 mapper.set_nm_device(vb_m.pp("norm"), false),
440 )?;
441 let lm_head = mapper.cast_nm_device(
442 embed_tokens.embeddings(),
443 normal_loading_metadata.loading_isq,
444 )?;
445 Ok(Self {
446 embed_tokens,
447 layers,
448 norm,
449 lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
450 Linear::new(lm_head, None),
451 ))?),
452 device: normal_loading_metadata.real_device,
453 hidden_size: cfg.hidden_size,
454 cache: EitherCache::Normal(NormalCache::new(
455 cfg.num_hidden_layers,
456 cfg.max_position_embeddings,
457 )),
458 max_seq_len: default_max_position_embeddings(),
459 cfg: ModelConfigMetadata {
460 max_seq_len: cfg.max_position_embeddings,
461 num_layers: cfg.num_hidden_layers,
462 hidden_size: cfg.hidden_size,
463 num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
464 num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
465 .max(1),
466 sliding_window: None,
467 k_head_dim: cfg.head_dim,
468 v_head_dim: cfg.head_dim,
469 },
470 mapper,
471 })
472 }
473
474 pub fn forward(
475 &self,
476 input_ids: &Tensor,
477 seqlen_offsets: &[usize],
478 context_lens: Vec<(usize, usize)>,
479 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
480 flash_params: &FlashParams,
481 ) -> Result<Tensor> {
482 let xs = self.embed_tokens.forward(input_ids)?;
483 let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
484 let cache = &mut self.cache.normal().0;
485 let attention_mask = CausalMasker.make_causal_mask_matrix(
486 input_ids,
487 metadata
488 .as_ref()
489 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
490 .unwrap_or(cache as &dyn PastKvLenCache),
491 xs.dtype(),
492 self.cfg.num_attn_heads,
493 )?;
494 let attention_mask = attention_mask.filter(|_| {
496 metadata
497 .as_ref()
498 .map(|(_, meta)| meta.is_first_prompt_chunk)
499 .unwrap_or(true)
500 });
501 for (i, layer) in self.layers.iter().enumerate() {
502 xs = self.mapper.map(xs, i)?;
503 xs = layer.forward(
504 &xs,
505 attention_mask
506 .as_ref()
507 .map(|m| m.to_device(xs.device()).unwrap())
508 .as_ref(),
509 seqlen_offsets,
510 &mut cache[i],
511 metadata
512 .as_ref()
513 .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
514 flash_params,
515 )?;
516 }
517 let xs = xs.to_device(&self.device)?;
518 let mut xs = xs.apply(&self.norm)?;
519 if let Some(t) = self.lm_head.quantized_act_type() {
520 xs = xs.to_dtype(t)?;
521 }
522 extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
523 }
524}
525
526impl IsqModel for Model {
527 fn get_layers(
528 &mut self,
529 ) -> (
530 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
531 &dyn DeviceMapper,
532 ) {
533 let mut tensors = Vec::new();
534 tensors.push((&mut self.lm_head, None));
535 for (i, layer) in self.layers.iter_mut().enumerate() {
536 tensors.push((&mut layer.self_attn.q_proj, Some(i)));
537 tensors.push((&mut layer.self_attn.k_proj, Some(i)));
538 tensors.push((&mut layer.self_attn.v_proj, Some(i)));
539 tensors.push((&mut layer.self_attn.o_proj, Some(i)));
540 tensors.extend(
541 layer
542 .mlp
543 .get_isq_layers()
544 .into_iter()
545 .map(|m| (m, Some(i)))
546 .collect::<Vec<_>>(),
547 );
548 }
549 (tensors, &*self.mapper)
550 }
551
552 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
553 let uvb = UnVarBuilder::new();
554
555 let uvb_m = uvb.pp("model");
556 uvb_m.pp("embed_tokens").add(&self.embed_tokens);
557 uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());
558
559 for (layer_idx, layer) in self.layers.iter().enumerate() {
560 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
561 uvb_l
562 .pp("input_layernorm")
563 .add(&layer.input_layernorm.undo_gemma().unwrap());
564 uvb_l
565 .pp("post_attention_layernorm")
566 .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
567 }
568
569 uvb.to_safetensors()
570 }
571
572 fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
573 let mut names = Vec::new();
575 names.push(None);
577 for i in 0..self.layers.len() {
578 names.push(Some(format!("blk.{i}.attn_q.weight")));
579 names.push(Some(format!("blk.{i}.attn_k.weight")));
580 names.push(Some(format!("blk.{i}.attn_v.weight")));
581 names.push(Some(format!("blk.{i}.attn_output.weight")));
582 names.push(Some(format!("blk.{i}.ffn_gate.weight")));
583 names.push(Some(format!("blk.{i}.ffn_up.weight")));
584 names.push(Some(format!("blk.{i}.ffn_down.weight")));
585 }
586 Ok(names)
587 }
588}
589
590impl NormalModel for Model {
591 fn forward(
592 &self,
593 input_ids: &Tensor,
594 seqlen_offsets: &[usize],
595 context_lens: Vec<(usize, usize)>,
596 _position_ids: Vec<usize>,
597 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
598 flash_params: &FlashParams,
599 ) -> Result<Tensor> {
600 self.forward(
601 input_ids,
602 seqlen_offsets,
603 context_lens,
604 metadata,
605 flash_params,
606 )
607 }
608 fn xlora_forward(
609 &self,
610 _input_ids: &Tensor,
611 _input_ids_full: &Tensor,
612 _seqlen_offsets: &[usize],
613 _seqlen_offsets_full: &[usize],
614 _no_kv_cache: bool,
615 _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
616 _context_lens: Vec<(usize, usize)>,
617 _position_ids: Vec<usize>,
618 _flash_params: &FlashParams,
619 _flash_params_full: &FlashParams,
620 ) -> Result<Tensor> {
621 unimplemented!()
622 }
623 fn cache(&self) -> &EitherCache {
624 &self.cache
625 }
626 fn cache_mut(&mut self) -> &mut EitherCache {
627 &mut self.cache
628 }
629 fn device(&self) -> &Device {
630 &self.device
631 }
632 fn is_xlora(&self) -> bool {
633 false
634 }
635 fn max_seq_len(&self) -> usize {
636 self.max_seq_len
637 }
638 fn config(&self) -> &ModelConfigMetadata {
639 &self.cfg
640 }
641}
642
643impl AnyMoeBaseModelMixin for Model {
644 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
645 let mut mlps = Vec::new();
646 for layer in &self.layers {
647 mlps.push(&*layer.mlp);
648 }
649 mlps
650 }
651 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
652 let mut mlps = Vec::new();
653 for layer in &mut self.layers {
654 mlps.push(&mut layer.mlp);
655 }
656 mlps
657 }
658 fn create_anymoe_layers(
659 &mut self,
660 additional_vbs: Vec<ShardedVarBuilder>,
661 config: AnyMoeConfig,
662 (prefix, mlp): (String, String),
663 mut layers: Vec<usize>,
664 expert_type: AnyMoeExpertType,
665 gate_vb: Option<ShardedVarBuilder>,
666 ) -> Result<()> {
667 let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
668 if layers.is_empty() {
669 layers = (0..self.layers.len()).collect::<Vec<_>>();
670 }
671 for _ in 0..layers.len() {
672 experts.push(Vec::new());
673 }
674 for vb in additional_vbs {
675 let vb = vb.pp(&prefix);
676 for (layer, row) in experts.iter_mut().enumerate() {
677 if !layers.contains(&layer) {
678 continue;
679 }
680
681 let intermediate_size = self.layers[layer].mlp.get_params()[1];
682 let hidden_size = self.layers[layer].mlp.get_params()[0];
683 match expert_type {
684 AnyMoeExpertType::FineTuned => {
685 let (dtype, device) = self.layers[layer].mlp.dtype_device();
686 row.push(Box::new(Mlp::replicate(
687 self.layers[layer].mlp.get_params(),
688 vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
689 self.layers[layer].mlp.hidden_act(),
690 &self.mapper.get_comm_for(layer)?,
691 )?));
692 }
693 AnyMoeExpertType::LoraAdapter {
694 rank,
695 alpha,
696 ref target_modules,
697 } => {
698 let vb_mlp = vb.pp(layer).pp(&mlp);
699
700 let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
701 Some(get_delta_from_lora_ab!(
702 vb_mlp,
703 rank,
704 alpha,
705 (hidden_size, intermediate_size),
706 "gate_proj"
707 ))
708 } else {
709 None
710 };
711 let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
712 Some(get_delta_from_lora_ab!(
713 vb_mlp,
714 rank,
715 alpha,
716 (hidden_size, intermediate_size),
717 "up_proj"
718 ))
719 } else {
720 None
721 };
722 let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
723 Some(get_delta_from_lora_ab!(
724 vb_mlp,
725 rank,
726 alpha,
727 (intermediate_size, hidden_size),
728 "down_proj"
729 ))
730 } else {
731 None
732 };
733
734 row.push(self.layers[layer].mlp.new_added_delta(vec![
735 gate_proj_delta,
736 up_proj_delta,
737 down_proj_delta,
738 ])?);
739 }
740 }
741 }
742 }
743 for (layer, expert) in layers.into_iter().zip(experts) {
744 let mut experts_all = vec![self.layers[layer].mlp.clone()];
745 experts_all.extend(expert);
746 let (dtype, device) = self.layers[layer].mlp.dtype_device();
747 self.layers[layer].mlp = Box::new(MoeMlp::new(
748 experts_all,
749 config.clone(),
750 dtype,
751 &device,
752 layer,
753 gate_vb.as_ref(),
754 )?);
755 }
756 Ok(())
757 }
758 fn amoe_supported(&self) -> bool {
759 true
760 }
761}