1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{DType, Device, Module, Result, Tensor};
4use mistralrs_quant::{
5 ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
6 ShardedVarBuilder,
7};
8use std::{collections::HashMap, sync::Arc};
9
10use crate::{
11 amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
12 attention::SdpaParams,
13 device_map::DeviceMapper,
14 get_delta_from_lora_ab,
15 layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
16 layers_masker::PastKvLenCache,
17 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
18 pipeline::{
19 extract_logits,
20 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
21 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
22 },
23 serde_default_fn,
24 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
25};
26
// Serde default used by `Config::tie_word_embeddings` (see its `#[serde(default = ...)]`);
// presumably expands to a `fn word_emb_default() -> bool` returning `false`.
serde_default_fn!(bool, word_emb_default, false);
// Serde default used by `Config::use_flash_attn`; presumably expands to a
// `fn use_flash_attn() -> bool` returning `false`.
serde_default_fn!(bool, use_flash_attn, false);
29
/// Model hyper-parameters (serde-deserializable) used to build every layer of [`Model`].
#[derive(Debug, Clone, serde::Deserialize, Default, serde::Serialize)]
pub struct Config {
    /// Number of token ids; sizes the token embedding and the `lm_head` output.
    pub vocab_size: usize,
    /// Model width. Also divided by `num_attention_heads` to obtain the per-head dimension.
    pub hidden_size: usize,
    /// Inner width of each decoder layer's MLP.
    pub intermediate_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Total number of query heads (before tensor-parallel sharding).
    pub num_attention_heads: usize,
    /// Total number of key/value heads (before tensor-parallel sharding).
    pub num_key_value_heads: usize,
    /// Length of the rotary-embedding table; also reported as the model's max sequence length.
    pub max_position_embeddings: usize,
    /// Window size used when building the sliding-window causal mask.
    /// NOTE(review): the attention `SdpaParams.sliding_window` is set to `None`, so only the
    /// mask enforces this window.
    pub sliding_window: usize,
    /// Rotary embedding base (cast to `f32` when the table is built).
    pub rope_theta: f64,
    /// Epsilon for every RMS norm in the model.
    pub rms_norm_eps: f64,
    /// Activation used inside the MLP.
    pub hidden_act: Activation,
    /// Whether attention should use flash attention; defaults to `false` when absent.
    #[serde(default = "use_flash_attn")]
    pub use_flash_attn: bool,
    /// Optional pre-quantization settings passed to every projection layer; logged at load
    /// time when present.
    pub quantization_config: Option<QuantizedConfig>,
    /// When `true`, `lm_head` reuses the token-embedding weight instead of loading its own;
    /// defaults to `false` when absent.
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}
49
/// Multi-head self-attention with grouped key/value heads and rotary position embeddings.
struct Attention {
    /// Query projection (column-parallel).
    q_proj: Arc<dyn QuantMethod>,
    /// Key projection (column-parallel, sharded with the KV shard layout).
    k_proj: Arc<dyn QuantMethod>,
    /// Value projection (column-parallel, sharded with the KV shard layout).
    v_proj: Arc<dyn QuantMethod>,
    /// Output projection (row-parallel).
    o_proj: Arc<dyn QuantMethod>,
    /// Query heads held by this rank (total heads divided by the world size).
    num_heads: usize,
    /// KV heads held by this rank (total KV heads divided by the world size, at least 1).
    num_kv_heads: usize,
    /// Per-head dimension (`hidden_size / num_attention_heads`, not sharded).
    head_dim: usize,
    /// Rotary embedding shared by all layers on the same device.
    rotary_emb: Arc<RotaryEmbedding>,
    /// Present only when the paged-attention implementation was selected.
    paged_attn: Option<PagedAttention>,
    /// Parameters for the scaled-dot-product attention call.
    sdpa_params: SdpaParams,
}
62
63impl Attention {
64 fn new(
65 rotary_emb: Arc<RotaryEmbedding>,
66 cfg: &Config,
67 vb: ShardedVarBuilder,
68 paged_attn: Option<PagedAttention>,
69 comm: &Arc<mistralrs_quant::Comm>,
70 ) -> Result<Self> {
71 let hidden_sz = cfg.hidden_size;
72 let num_heads = cfg.num_attention_heads;
73 let num_kv_heads = cfg.num_key_value_heads;
74 let head_dim = hidden_sz / num_heads;
75 let q_proj = ColumnParallelLayer::new(
76 hidden_sz,
77 num_heads * head_dim,
78 &cfg.quantization_config,
79 true,
80 comm,
81 vb.pp("q_proj"),
82 )?;
83 let kv_shard = mistralrs_quant::compute_kv_shard(
84 cfg.num_key_value_heads,
85 cfg.hidden_size / cfg.num_attention_heads,
86 comm,
87 );
88 let k_proj = ColumnParallelLayer::new_with_shard(
89 hidden_sz,
90 num_kv_heads * head_dim,
91 &cfg.quantization_config,
92 true,
93 comm,
94 kv_shard,
95 vb.pp("k_proj"),
96 )?;
97 let v_proj = ColumnParallelLayer::new_with_shard(
98 hidden_sz,
99 num_kv_heads * head_dim,
100 &cfg.quantization_config,
101 true,
102 comm,
103 kv_shard,
104 vb.pp("v_proj"),
105 )?;
106 let o_proj = RowParallelLayer::new(
107 num_heads * head_dim,
108 hidden_sz,
109 &cfg.quantization_config,
110 false,
111 comm,
112 vb.pp("o_proj"),
113 )?;
114 Ok(Self {
115 q_proj,
116 k_proj,
117 v_proj,
118 o_proj,
119 num_heads: num_heads / comm.world_size(),
120 num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
121 head_dim,
122 rotary_emb,
123 paged_attn,
124 sdpa_params: SdpaParams {
125 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
126 cfg.num_key_value_heads,
127 cfg.num_attention_heads,
128 comm,
129 ),
130 use_flash_attn: cfg.use_flash_attn,
131 softcap: None,
132 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
133 sliding_window: None,
134 },
135 })
136 }
137
138 #[allow(clippy::too_many_arguments)]
139 fn forward(
140 &self,
141 xs: &Tensor,
142 attention_mask: Option<&Tensor>,
143 seqlen_offsets: &[usize],
144 kv_cache: &mut KvCache,
145 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
146 flash_params: &FlashParams,
147 ) -> Result<Tensor> {
148 let (b_sz, q_len, _) = xs.dims3()?;
149
150 let original_dtype = xs.dtype();
151 let mut xs = xs.clone();
152 if let Some(t) = self.q_proj.quantized_act_type() {
153 xs = xs.to_dtype(t)?;
154 }
155 let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
156 let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
157 let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
158 if self.q_proj.quantized_act_type().is_some() {
159 q = q.to_dtype(original_dtype)?;
160 k = k.to_dtype(original_dtype)?;
161 v = v.to_dtype(original_dtype)?;
162 }
163
164 let (q, k, v) = if q_len != 1 {
165 let q = q
166 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
167 .transpose(1, 2)?;
168 let k = k
169 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
170 .transpose(1, 2)?;
171 let v = v
172 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
173 .transpose(1, 2)?;
174 (q, k, v)
175 } else {
176 let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
177 let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
178 let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
179 (q, k, v)
180 };
181
182 let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
183
184 let mut attn_output = match &self.paged_attn {
185 Some(paged_attn) => match metadata {
186 Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
187 &q,
188 &k,
189 &v,
190 attention_mask,
191 Some(key_cache),
192 Some(value_cache),
193 input_metadata,
194 &self.sdpa_params,
195 Some(flash_params),
196 )?,
197 None => {
198 let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
201 assert!(attention_mask.is_some());
203 paged_attn.forward(
204 &q,
205 &k,
206 &v,
207 attention_mask,
208 None,
209 None,
210 &input_metadata,
211 &self.sdpa_params,
212 Some(flash_params),
213 )?
214 }
215 },
216 None => {
217 let (k, v) = kv_cache.append(&k, &v)?;
218
219 Sdpa.run_attention(
220 &q,
221 &k,
222 &v,
223 attention_mask,
224 Some(flash_params),
225 &self.sdpa_params,
226 )?
227 }
228 };
229
230 if let Some(t) = self.q_proj.quantized_act_type() {
231 attn_output = attn_output.to_dtype(t)?;
232 }
233 attn_output = if attention_mask.is_some() {
234 attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
235 } else {
236 attn_output.reshape((b_sz, q_len, ()))?
237 };
238 let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
239 if self.q_proj.quantized_act_type().is_some() {
240 res = res.to_dtype(original_dtype)?;
241 }
242 Ok(res)
243 }
244}
245
/// One pre-norm decoder layer: RMS norm -> attention -> residual add, then
/// RMS norm -> MLP -> residual add.
struct DecoderLayer {
    /// Self-attention sub-block.
    self_attn: Attention,
    /// Feed-forward sub-block; boxed as a trait object so AnyMoE can replace it with a `MoeMlp`.
    mlp: Box<dyn MlpLayer>,
    /// Norm applied to the layer input before attention.
    input_layernorm: RmsNorm,
    /// Norm applied to the post-attention residual stream before the MLP.
    post_attention_layernorm: RmsNorm,
}
252
253impl DecoderLayer {
254 #[allow(clippy::too_many_arguments)]
255 fn new(
256 rotary_emb: Arc<RotaryEmbedding>,
257 cfg: &Config,
258 vb: ShardedVarBuilder,
259 mapper: &dyn DeviceMapper,
260 layer_idx: usize,
261 loading_isq: bool,
262 paged_attn: Option<PagedAttention>,
263 comm: &Arc<mistralrs_quant::Comm>,
264 ) -> Result<Self> {
265 let self_attn = Attention::new(
266 rotary_emb,
267 cfg,
268 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
269 paged_attn,
270 comm,
271 )?;
272 let mlp = Mlp::new(
273 mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
274 cfg.hidden_size,
275 cfg.intermediate_size,
276 &cfg.quantization_config,
277 cfg.hidden_act,
278 comm,
279 )?;
280 let input_layernorm = RmsNorm::new(
281 cfg.hidden_size,
282 cfg.rms_norm_eps,
283 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
284 )?;
285 let post_attention_layernorm = RmsNorm::new(
286 cfg.hidden_size,
287 cfg.rms_norm_eps,
288 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
289 )?;
290 Ok(Self {
291 self_attn,
292 mlp: Box::new(mlp),
293 input_layernorm,
294 post_attention_layernorm,
295 })
296 }
297
298 #[allow(clippy::too_many_arguments)]
299 fn forward(
300 &self,
301 xs: &Tensor,
302 attention_mask: Option<&Tensor>,
303 seqlen_offsets: &[usize],
304 kv_cache: &mut KvCache,
305 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
306 flash_params: &FlashParams,
307 ) -> Result<Tensor> {
308 let residual = xs;
309 let xs = self.input_layernorm.forward(xs)?;
310 let xs = self.self_attn.forward(
311 &xs,
312 attention_mask,
313 seqlen_offsets,
314 kv_cache,
315 metadata,
316 flash_params,
317 )?;
318 let xs = (xs + residual)?;
319 let residual = &xs;
320 let xs = self
321 .mlp
322 .forward(&xs.apply(&self.post_attention_layernorm)?)?;
323 residual + xs
324 }
325}
326
/// Decoder-only language model: token embedding, a stack of [`DecoderLayer`]s, a final RMS
/// norm and an LM head producing vocabulary logits.
pub struct Model {
    /// Token embedding table.
    embed_tokens: candle_nn::Embedding,
    /// Decoder layers, applied in order.
    layers: Vec<DecoderLayer>,
    /// Final RMS norm applied before `lm_head`.
    norm: RmsNorm,
    /// Output projection to logits; built from the embedding weight when
    /// `tie_word_embeddings` is set.
    lm_head: Arc<dyn QuantMethod>,
    /// Window size passed to the sliding-window causal mask builder.
    sliding_window: usize,
    /// Device that hidden states are moved to after the last layer (final norm / `lm_head`).
    device: Device,
    /// Per-layer KV cache; appended to by the non-paged attention path.
    cache: EitherCache,
    /// Maximum sequence length (`max_position_embeddings`).
    max_seq_len: usize,
    /// Maps layers to devices and supplies per-layer communicators.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Metadata returned from `NormalModel::config`.
    cfg: ModelConfigMetadata,
}
339
340impl Model {
341 pub fn new(
342 cfg: &Config,
343 vb: ShardedVarBuilder,
344 is_gptx: bool,
345 normal_loading_metadata: NormalLoadingMetadata,
346 attention_mechanism: AttentionImplementation,
347 ) -> Result<Self> {
348 if let Some(ref quant_cfg) = &cfg.quantization_config {
349 tracing::info!(
350 "Using {} quantization: {}.",
351 quant_cfg.quant_method.to_string(),
352 quant_cfg.get_bits_name(&vb)
353 );
354 }
355 let mapper = normal_loading_metadata.mapper;
356 let vb_m = vb.pp("model");
357
358 let embed_tokens = embedding(
359 cfg.vocab_size,
360 cfg.hidden_size,
361 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
362 )?;
363 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
364 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
365
366 let mut ropes = HashMap::new();
367 for layer_idx in 0..cfg.num_hidden_layers {
368 let device = mapper
369 .device_for(layer_idx, false)
370 .unwrap_or(&normal_loading_metadata.real_device);
371 ropes.insert(
372 device.location(),
373 Arc::new(RotaryEmbedding::new(
374 cfg.rope_theta as f32,
375 head_dim,
376 cfg.max_position_embeddings,
377 device,
378 is_gptx,
379 vb_m.dtype(),
380 )?),
381 );
382 }
383
384 let vb_l = vb_m.pp("layers");
385 for layer_idx in NiceProgressBar::<_, 'b'>(
386 0..cfg.num_hidden_layers,
387 "Loading repeating layers",
388 &normal_loading_metadata.multi_progress,
389 ) {
390 let device = mapper
391 .device_for(layer_idx, false)
392 .unwrap_or(&normal_loading_metadata.real_device);
393 let rotary_emb = ropes
394 .get(&device.location())
395 .expect("No RoPE for device location!")
396 .clone();
397 let paged_attn = match &attention_mechanism {
398 AttentionImplementation::Eager => None,
399 AttentionImplementation::PagedAttention => {
400 Some(PagedAttention::new(head_dim, device, None)?)
401 }
402 };
403 let comm = mapper.get_comm_for(layer_idx)?;
404 let layer = DecoderLayer::new(
405 rotary_emb.clone(),
406 cfg,
407 vb_l.pp(layer_idx),
408 &*mapper,
409 layer_idx,
410 normal_loading_metadata.loading_isq,
411 paged_attn,
412 &comm,
413 )?;
414 layers.push(layer)
415 }
416 let norm = RmsNorm::new(
417 cfg.hidden_size,
418 cfg.rms_norm_eps,
419 mapper.set_nm_device(vb_m.pp("norm"), false),
420 )?;
421 let lm_head = if !cfg.tie_word_embeddings {
422 ReplicatedLayer::new(
423 cfg.hidden_size,
424 cfg.vocab_size,
425 &None,
426 false,
427 mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
428 )?
429 } else {
430 ReplicatedLayer::from_linear(candle_nn::Linear::new(
431 mapper.cast_nm_device(
432 embed_tokens.embeddings(),
433 normal_loading_metadata.loading_isq,
434 )?,
435 None,
436 ))?
437 };
438 Ok(Self {
439 embed_tokens,
440 layers,
441 norm,
442 lm_head,
443 sliding_window: cfg.sliding_window,
444 device: normal_loading_metadata.real_device,
445 cache: EitherCache::Normal(NormalCache::new(
446 cfg.num_hidden_layers,
447 cfg.max_position_embeddings,
448 )),
449 max_seq_len: cfg.max_position_embeddings,
450 cfg: ModelConfigMetadata {
451 max_seq_len: cfg.max_position_embeddings,
452 num_layers: cfg.num_hidden_layers,
453 hidden_size: cfg.hidden_size,
454 num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
455 num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
456 .max(1),
457 sliding_window: Some(cfg.sliding_window),
458 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
459 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
460 },
461 mapper,
462 })
463 }
464
465 pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
466 self.embed_tokens.forward(input_ids)
467 }
468
469 pub fn forward(
470 &self,
471 input_ids: &Tensor,
472 seqlen_offsets: &[usize],
473 context_lens: Vec<(usize, usize)>,
474 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
475 flash_params: &FlashParams,
476 ) -> Result<Tensor> {
477 let xs = self.embed_tokens.forward(input_ids)?;
478 self.forward_embed(
479 input_ids,
480 xs,
481 seqlen_offsets,
482 context_lens,
483 metadata,
484 flash_params,
485 )
486 }
487
488 #[allow(clippy::too_many_arguments)]
489 pub fn forward_embed(
490 &self,
491 input_ids: &Tensor,
492 mut xs: Tensor,
493 seqlen_offsets: &[usize],
494 context_lens: Vec<(usize, usize)>,
495 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
496 flash_params: &FlashParams,
497 ) -> Result<Tensor> {
498 let cache = &mut self.cache.normal().0;
499 let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
500 input_ids,
501 metadata
502 .as_ref()
503 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
504 .unwrap_or(cache as &dyn PastKvLenCache),
505 Some(self.sliding_window),
506 xs.dtype(),
507 self.cfg.num_attn_heads,
508 )?;
509 let attention_mask = attention_mask.filter(|_| {
510 metadata
511 .as_ref()
512 .map(|(_, meta)| meta.is_first_prompt_chunk)
513 .unwrap_or(true)
514 });
515 for (i, layer) in self.layers.iter().enumerate() {
516 xs = self.mapper.map(xs, i)?;
517 xs = layer.forward(
518 &xs,
519 attention_mask
520 .as_ref()
521 .map(|m| m.to_device(xs.device()).unwrap())
522 .as_ref(),
523 seqlen_offsets,
524 &mut cache[i],
525 metadata
526 .as_ref()
527 .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
528 flash_params,
529 )?
530 }
531 let xs = xs.to_device(&self.device)?;
532 let mut xs = xs.apply(&self.norm)?;
533 if let Some(t) = self.lm_head.quantized_act_type() {
534 xs = xs.to_dtype(t)?;
535 }
536 extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
537 }
538
539 pub fn embed_dtype(&self) -> DType {
540 self.embed_tokens.embeddings().dtype()
541 }
542}
543
544impl IsqModel for Model {
545 fn get_layers(
546 &mut self,
547 ) -> (
548 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
549 &dyn DeviceMapper,
550 ) {
551 let mut tensors = Vec::new();
552 tensors.push((&mut self.lm_head, None));
553 for (i, layer) in self.layers.iter_mut().enumerate() {
554 tensors.push((&mut layer.self_attn.q_proj, Some(i)));
555 tensors.push((&mut layer.self_attn.k_proj, Some(i)));
556 tensors.push((&mut layer.self_attn.v_proj, Some(i)));
557 tensors.push((&mut layer.self_attn.o_proj, Some(i)));
558 tensors.extend(
559 layer
560 .mlp
561 .get_isq_layers()
562 .into_iter()
563 .map(|m| (m, Some(i)))
564 .collect::<Vec<_>>(),
565 );
566 }
567 (tensors, &*self.mapper)
568 }
569
570 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
571 let uvb = UnVarBuilder::new();
572
573 let uvb_m = uvb.pp("model");
574 uvb_m.pp("embed_tokens").add(&self.embed_tokens);
575 uvb_m.pp("norm").add(&self.norm);
576
577 for (layer_idx, layer) in self.layers.iter().enumerate() {
578 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
579 uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
580 uvb_l
581 .pp("post_attention_layernorm")
582 .add(&layer.post_attention_layernorm);
583 }
584
585 uvb.to_safetensors()
586 }
587
588 fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
589 let mut names = Vec::new();
591 names.push(None);
593 for i in 0..self.layers.len() {
594 names.push(Some(format!("blk.{i}.attn_q.weight")));
595 names.push(Some(format!("blk.{i}.attn_k.weight")));
596 names.push(Some(format!("blk.{i}.attn_v.weight")));
597 names.push(Some(format!("blk.{i}.attn_output.weight")));
598 names.push(Some(format!("blk.{i}.ffn_gate.weight")));
599 names.push(Some(format!("blk.{i}.ffn_up.weight")));
600 names.push(Some(format!("blk.{i}.ffn_down.weight")));
601 }
602 Ok(names)
603 }
604}
605
606impl NormalModel for Model {
607 fn forward(
608 &self,
609 input_ids: &Tensor,
610 seqlen_offsets: &[usize],
611 context_lens: Vec<(usize, usize)>,
612 _position_ids: Vec<usize>,
613 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
614 flash_params: &FlashParams,
615 ) -> Result<Tensor> {
616 self.forward(
617 input_ids,
618 seqlen_offsets,
619 context_lens,
620 metadata,
621 flash_params,
622 )
623 }
624 fn xlora_forward(
625 &self,
626 _input_ids: &Tensor,
627 _input_ids_full: &Tensor,
628 _seqlen_offsets: &[usize],
629 _seqlen_offsets_full: &[usize],
630 _no_kv_cache: bool,
631 _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
632 _context_lens: Vec<(usize, usize)>,
633 _position_ids: Vec<usize>,
634 _flash_params: &FlashParams,
635 _flash_params_full: &FlashParams,
636 ) -> Result<Tensor> {
637 unimplemented!()
638 }
639 fn cache(&self) -> &EitherCache {
640 &self.cache
641 }
642 fn cache_mut(&mut self) -> &mut EitherCache {
643 &mut self.cache
644 }
645 fn device(&self) -> &Device {
646 &self.device
647 }
648 fn is_xlora(&self) -> bool {
649 false
650 }
651 fn max_seq_len(&self) -> usize {
652 self.max_seq_len
653 }
654 fn config(&self) -> &ModelConfigMetadata {
655 &self.cfg
656 }
657}
658
659impl AnyMoeBaseModelMixin for Model {
660 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
661 let mut mlps = Vec::new();
662 for layer in &self.layers {
663 mlps.push(&*layer.mlp);
664 }
665 mlps
666 }
667 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
668 let mut mlps = Vec::new();
669 for layer in &mut self.layers {
670 mlps.push(&mut layer.mlp);
671 }
672 mlps
673 }
674 fn create_anymoe_layers(
675 &mut self,
676 additional_vbs: Vec<ShardedVarBuilder>,
677 config: AnyMoeConfig,
678 (prefix, mlp): (String, String),
679 mut layers: Vec<usize>,
680 expert_type: AnyMoeExpertType,
681 gate_vb: Option<ShardedVarBuilder>,
682 ) -> Result<()> {
683 let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
684 if layers.is_empty() {
685 layers = (0..self.layers.len()).collect::<Vec<_>>();
686 }
687 for _ in 0..layers.len() {
688 experts.push(Vec::new());
689 }
690 for vb in additional_vbs {
691 let vb = vb.pp(&prefix);
692 for (layer, row) in experts.iter_mut().enumerate() {
693 if !layers.contains(&layer) {
694 continue;
695 }
696
697 let intermediate_size = self.layers[layer].mlp.get_params()[1];
698 let hidden_size = self.layers[layer].mlp.get_params()[0];
699 match expert_type {
700 AnyMoeExpertType::FineTuned => {
701 let (dtype, device) = self.layers[layer].mlp.dtype_device();
702 row.push(Box::new(Mlp::replicate(
703 self.layers[layer].mlp.get_params(),
704 vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
705 self.layers[layer].mlp.hidden_act(),
706 &self.mapper.get_comm_for(layer)?,
707 )?));
708 }
709 AnyMoeExpertType::LoraAdapter {
710 rank,
711 alpha,
712 ref target_modules,
713 } => {
714 let vb_mlp = vb.pp(layer).pp(&mlp);
715
716 let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
717 Some(get_delta_from_lora_ab!(
718 vb_mlp,
719 rank,
720 alpha,
721 (hidden_size, intermediate_size),
722 "gate_proj"
723 ))
724 } else {
725 None
726 };
727 let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
728 Some(get_delta_from_lora_ab!(
729 vb_mlp,
730 rank,
731 alpha,
732 (hidden_size, intermediate_size),
733 "up_proj"
734 ))
735 } else {
736 None
737 };
738 let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
739 Some(get_delta_from_lora_ab!(
740 vb_mlp,
741 rank,
742 alpha,
743 (intermediate_size, hidden_size),
744 "down_proj"
745 ))
746 } else {
747 None
748 };
749
750 row.push(self.layers[layer].mlp.new_added_delta(vec![
751 gate_proj_delta,
752 up_proj_delta,
753 down_proj_delta,
754 ])?);
755 }
756 }
757 }
758 }
759 for (layer, expert) in layers.into_iter().zip(experts) {
760 let mut experts_all = vec![self.layers[layer].mlp.clone()];
761 experts_all.extend(expert);
762 let (dtype, device) = self.layers[layer].mlp.dtype_device();
763 self.layers[layer].mlp = Box::new(MoeMlp::new(
764 experts_all,
765 config.clone(),
766 dtype,
767 &device,
768 layer,
769 gate_vb.as_ref(),
770 )?);
771 }
772 Ok(())
773 }
774 fn amoe_supported(&self) -> bool {
775 true
776 }
777}