1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{Device, Module, Result, Tensor};
5use mistralrs_quant::{
6 ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
7 ShardedVarBuilder,
8};
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, sync::Arc};
11
12use crate::{
13 amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
14 attention::SdpaParams,
15 device_map::DeviceMapper,
16 get_delta_from_lora_ab,
17 layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
18 layers_masker::PastKvLenCache,
19 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
20 pipeline::{
21 extract_logits,
22 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
23 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
24 },
25 serde_default_fn,
26 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
27};
28
29serde_default_fn!(bool, use_flash_attn, false);
30serde_default_fn!(bool, tie_word_embeddings, false);
31
32#[derive(Debug, Clone, Default, Serialize, Deserialize)]
33pub struct Config {
34 pub(crate) vocab_size: usize,
35 pub(crate) hidden_size: usize,
36 pub(crate) intermediate_size: usize,
37 pub(crate) num_hidden_layers: usize,
38 pub(crate) num_attention_heads: usize,
39 pub(crate) num_key_value_heads: usize,
40 pub(crate) hidden_act: Activation,
41 pub(crate) max_position_embeddings: usize,
42 pub(crate) rms_norm_eps: f64,
43 pub(crate) rope_theta: f64,
44 pub(crate) sliding_window: Option<usize>,
45 #[serde(default = "use_flash_attn")]
46 pub(crate) use_flash_attn: bool,
47 pub(crate) head_dim: Option<usize>,
48 pub(crate) quantization_config: Option<QuantizedConfig>,
49 #[serde(default = "tie_word_embeddings")]
50 pub(crate) tie_word_embeddings: bool,
51}
52
53impl Config {
54 pub(crate) fn head_dim(&self) -> usize {
55 self.head_dim
56 .unwrap_or(self.hidden_size / self.num_attention_heads)
57 }
58}
59
60struct Attention {
61 q_proj: Arc<dyn QuantMethod>,
62 k_proj: Arc<dyn QuantMethod>,
63 v_proj: Arc<dyn QuantMethod>,
64 o_proj: Arc<dyn QuantMethod>,
65 num_heads: usize,
66 num_kv_heads: usize,
67 head_dim: usize,
68 rotary_emb: Arc<RotaryEmbedding>,
69 paged_attn: Option<PagedAttention>,
70 sdpa_params: SdpaParams,
71}
72
73impl Attention {
74 fn new(
75 rotary_emb: Arc<RotaryEmbedding>,
76 cfg: &Config,
77 vb: ShardedVarBuilder,
78 paged_attn: Option<PagedAttention>,
79 comm: &Arc<mistralrs_quant::Comm>,
80 ) -> Result<Self> {
81 let hidden_sz = cfg.hidden_size;
82 let num_heads = cfg.num_attention_heads;
83 let num_kv_heads = cfg.num_key_value_heads;
84 let head_dim = cfg.head_dim();
85 let q_proj = ColumnParallelLayer::new(
86 hidden_sz,
87 num_heads * head_dim,
88 &cfg.quantization_config,
89 false,
90 comm,
91 vb.pp("q_proj"),
92 )?;
93 let kv_shard = mistralrs_quant::compute_kv_shard(
94 cfg.num_key_value_heads,
95 cfg.hidden_size / cfg.num_attention_heads,
96 comm,
97 );
98 let k_proj = ColumnParallelLayer::new_with_shard(
99 hidden_sz,
100 num_kv_heads * head_dim,
101 &cfg.quantization_config,
102 false,
103 comm,
104 kv_shard,
105 vb.pp("k_proj"),
106 )?;
107 let v_proj = ColumnParallelLayer::new_with_shard(
108 hidden_sz,
109 num_kv_heads * head_dim,
110 &cfg.quantization_config,
111 false,
112 comm,
113 kv_shard,
114 vb.pp("v_proj"),
115 )?;
116 let o_proj = RowParallelLayer::new(
117 num_heads * head_dim,
118 hidden_sz,
119 &cfg.quantization_config,
120 false,
121 comm,
122 vb.pp("o_proj"),
123 )?;
124 Ok(Self {
125 q_proj,
126 k_proj,
127 v_proj,
128 o_proj,
129 num_heads: num_heads / comm.world_size(),
130 num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
131 head_dim,
132 rotary_emb,
133 paged_attn,
134 sdpa_params: SdpaParams {
135 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
136 cfg.num_key_value_heads,
137 cfg.num_attention_heads,
138 comm,
139 ),
140 use_flash_attn: cfg.use_flash_attn,
141 softcap: None,
142 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
143 sliding_window: cfg.sliding_window,
144 },
145 })
146 }
147
148 #[allow(clippy::too_many_arguments)]
149 fn forward(
150 &self,
151 xs: &Tensor,
152 attention_mask: Option<&Tensor>,
153 seqlen_offsets: &[usize],
154 kv_cache: &mut KvCache,
155 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
156 flash_params: &FlashParams,
157 ) -> Result<Tensor> {
158 let (b_sz, q_len, _) = xs.dims3()?;
159
160 let original_dtype = xs.dtype();
161 let mut xs = xs.clone();
162 if let Some(t) = self.q_proj.quantized_act_type() {
163 xs = xs.to_dtype(t)?;
164 }
165 let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
166 let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
167 let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
168 if self.q_proj.quantized_act_type().is_some() {
169 q = q.to_dtype(original_dtype)?;
170 k = k.to_dtype(original_dtype)?;
171 v = v.to_dtype(original_dtype)?;
172 }
173
174 let (q, k, v) = if q_len != 1 {
175 let q = q
176 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
177 .transpose(1, 2)?;
178 let k = k
179 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
180 .transpose(1, 2)?;
181 let v = v
182 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
183 .transpose(1, 2)?;
184 (q, k, v)
185 } else {
186 let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
187 let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
188 let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
189 (q, k, v)
190 };
191
192 let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
193
194 let mut attn_output = match &self.paged_attn {
195 Some(paged_attn) => match metadata {
196 Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
197 &q,
198 &k,
199 &v,
200 attention_mask,
201 Some(key_cache),
202 Some(value_cache),
203 input_metadata,
204 &self.sdpa_params,
205 Some(flash_params),
206 )?,
207 None => {
208 let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
211 assert!(attention_mask.is_some());
213 paged_attn.forward(
214 &q,
215 &k,
216 &v,
217 attention_mask,
218 None,
219 None,
220 &input_metadata,
221 &self.sdpa_params,
222 Some(flash_params),
223 )?
224 }
225 },
226 None => {
227 let (k, v) = kv_cache.append(&k, &v)?;
228
229 Sdpa.run_attention(
230 &q,
231 &k,
232 &v,
233 attention_mask,
234 Some(flash_params),
235 &self.sdpa_params,
236 )?
237 }
238 };
239
240 if let Some(t) = self.q_proj.quantized_act_type() {
241 attn_output = attn_output.to_dtype(t)?;
242 }
243 attn_output = if attention_mask.is_some() {
244 attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
245 } else {
246 attn_output.reshape((b_sz, q_len, ()))?
247 };
248 let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
249 if self.q_proj.quantized_act_type().is_some() {
250 res = res.to_dtype(original_dtype)?;
251 }
252 Ok(res)
253 }
254}
255
256struct DecoderLayer {
257 self_attn: Attention,
258 mlp: Box<dyn MlpLayer>,
259 input_layernorm: RmsNorm,
260 post_attention_layernorm: RmsNorm,
261}
262
263impl DecoderLayer {
264 #[allow(clippy::too_many_arguments)]
265 fn new(
266 rotary_emb: Arc<RotaryEmbedding>,
267 cfg: &Config,
268 vb: ShardedVarBuilder,
269 mapper: &dyn DeviceMapper,
270 layer_idx: usize,
271 loading_isq: bool,
272 paged_attn: Option<PagedAttention>,
273 comm: &Arc<mistralrs_quant::Comm>,
274 ) -> Result<Self> {
275 let self_attn = Attention::new(
276 rotary_emb,
277 cfg,
278 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
279 paged_attn,
280 comm,
281 )?;
282 let mlp = Mlp::new(
283 mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
284 cfg.hidden_size,
285 cfg.intermediate_size,
286 &cfg.quantization_config,
287 cfg.hidden_act,
288 comm,
289 )?;
290 let input_layernorm = RmsNorm::new(
291 cfg.hidden_size,
292 cfg.rms_norm_eps,
293 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
294 )?;
295 let post_attention_layernorm = RmsNorm::new(
296 cfg.hidden_size,
297 cfg.rms_norm_eps,
298 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
299 )?;
300 Ok(Self {
301 self_attn,
302 mlp: Box::new(mlp),
303 input_layernorm,
304 post_attention_layernorm,
305 })
306 }
307
308 #[allow(clippy::too_many_arguments)]
309 fn forward(
310 &self,
311 xs: &Tensor,
312 attention_mask: Option<&Tensor>,
313 seqlen_offsets: &[usize],
314 kv_cache: &mut KvCache,
315 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
316 flash_params: &FlashParams,
317 ) -> Result<Tensor> {
318 let residual = xs;
319 let xs = self.input_layernorm.forward(xs)?;
320 let xs = self.self_attn.forward(
321 &xs,
322 attention_mask,
323 seqlen_offsets,
324 kv_cache,
325 metadata,
326 flash_params,
327 )?;
328 let xs = (xs + residual)?;
329 let residual = &xs;
330 let xs = self
331 .mlp
332 .forward(&xs.apply(&self.post_attention_layernorm)?)?;
333 residual + xs
334 }
335}
336
337pub struct Model {
338 embed_tokens: candle_nn::Embedding,
339 layers: Vec<DecoderLayer>,
340 norm: RmsNorm,
341 lm_head: Arc<dyn QuantMethod>,
342 sliding_window: Option<usize>,
343 device: Device,
344 pub(crate) cache: EitherCache,
345 max_seq_len: usize,
346 mapper: Box<dyn DeviceMapper + Send + Sync>,
347 cfg: ModelConfigMetadata,
348}
349
350impl Model {
351 pub fn new(
352 cfg: &Config,
353 vb: ShardedVarBuilder,
354 is_gptx: bool,
355 normal_loading_metadata: NormalLoadingMetadata,
356 attention_mechanism: AttentionImplementation,
357 ) -> Result<Self> {
358 let vb_m = vb.pp("model");
359 let vb_lm_head = vb.pp("lm_head");
360 Self::new_inner(
361 cfg,
362 vb_m,
363 vb_lm_head,
364 is_gptx,
365 normal_loading_metadata,
366 attention_mechanism,
367 )
368 }
369
370 pub fn new_inner(
371 cfg: &Config,
372 vb_m: ShardedVarBuilder,
373 vb_lm_head: ShardedVarBuilder,
374 is_gptx: bool,
375 normal_loading_metadata: NormalLoadingMetadata,
376 attention_mechanism: AttentionImplementation,
377 ) -> Result<Self> {
378 if let Some(ref quant_cfg) = &cfg.quantization_config {
379 tracing::info!(
380 "Using {} quantization: {}.",
381 quant_cfg.name(),
382 quant_cfg.get_bits_name(&vb_m)
383 );
384 }
385 let mapper = normal_loading_metadata.mapper;
386
387 let embed_tokens = embedding(
388 cfg.vocab_size,
389 cfg.hidden_size,
390 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
391 &cfg.quantization_config,
392 )?;
393
394 let head_dim = cfg.head_dim();
395 let mut ropes = HashMap::new();
396 for layer_idx in 0..cfg.num_hidden_layers {
397 let device = mapper
398 .device_for(layer_idx, false)
399 .unwrap_or(&normal_loading_metadata.real_device);
400 ropes.insert(
401 device.location(),
402 Arc::new(RotaryEmbedding::new(
403 cfg.rope_theta as f32,
404 head_dim,
405 cfg.max_position_embeddings,
406 device,
407 is_gptx,
408 vb_m.dtype(),
409 )?),
410 );
411 }
412
413 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
414 let vb_l = vb_m.pp("layers");
415 for layer_idx in NiceProgressBar::<_, 'b'>(
416 0..cfg.num_hidden_layers,
417 "Loading repeating layers",
418 &normal_loading_metadata.multi_progress,
419 ) {
420 let device = mapper
421 .device_for(layer_idx, false)
422 .unwrap_or(&normal_loading_metadata.real_device);
423 let rotary_emb = ropes
424 .get(&device.location())
425 .expect("No RoPE for device location!")
426 .clone();
427 let paged_attn = match &attention_mechanism {
428 AttentionImplementation::Eager => None,
429 AttentionImplementation::PagedAttention => {
430 Some(PagedAttention::new(head_dim, device, None)?)
431 }
432 };
433 let comm = mapper.get_comm_for(layer_idx)?;
434 let layer = DecoderLayer::new(
435 rotary_emb.clone(),
436 cfg,
437 vb_l.pp(layer_idx),
438 &*mapper,
439 layer_idx,
440 normal_loading_metadata.loading_isq,
441 paged_attn,
442 &comm,
443 )?;
444 layers.push(layer)
445 }
446 let norm = RmsNorm::new(
447 cfg.hidden_size,
448 cfg.rms_norm_eps,
449 mapper.set_nm_device(vb_m.pp("norm"), false),
450 )?;
451 let lm_head = if !cfg.tie_word_embeddings {
452 ReplicatedLayer::new(
453 cfg.hidden_size,
454 cfg.vocab_size,
455 &None,
456 false,
457 mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
458 )?
459 } else {
460 ReplicatedLayer::from_linear(candle_nn::Linear::new(
461 mapper.cast_nm_device(
462 embed_tokens.embeddings(),
463 normal_loading_metadata.loading_isq,
464 )?,
465 None,
466 ))?
467 };
468 Ok(Self {
469 embed_tokens,
470 layers,
471 norm,
472 lm_head,
473 sliding_window: cfg.sliding_window,
474 device: normal_loading_metadata.real_device,
475 cache: EitherCache::Normal(NormalCache::new_sliding(
476 cfg.num_hidden_layers,
477 cfg.max_position_embeddings,
478 cfg.sliding_window,
479 )),
480 max_seq_len: cfg.max_position_embeddings,
481 cfg: ModelConfigMetadata {
482 max_seq_len: cfg.max_position_embeddings,
483 num_layers: cfg.num_hidden_layers,
484 hidden_size: cfg.hidden_size,
485 num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
486 .max(1),
487 num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
488 sliding_window: cfg.sliding_window,
489 k_head_dim: cfg.head_dim(),
490 v_head_dim: cfg.head_dim(),
491 },
492 mapper,
493 })
494 }
495
496 pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
497 self.embed_tokens.forward(input_ids)
498 }
499
500 pub fn forward(
501 &self,
502 input_ids: &Tensor,
503 seqlen_offsets: &[usize],
504 context_lens: Vec<(usize, usize)>,
505 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
506 flash_params: &FlashParams,
507 ) -> Result<Tensor> {
508 self.forward_embeds(
509 input_ids,
510 self.embed_tokens.forward(input_ids)?,
511 seqlen_offsets,
512 context_lens,
513 metadata,
514 flash_params,
515 )
516 }
517
518 #[allow(clippy::too_many_arguments)]
519 pub fn forward_embeds(
520 &self,
521 input_ids: &Tensor,
522 input_embeds: Tensor,
523 seqlen_offsets: &[usize],
524 context_lens: Vec<(usize, usize)>,
525 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
526 flash_params: &FlashParams,
527 ) -> Result<Tensor> {
528 let mut xs = input_embeds;
529 let cache = &mut self.cache.normal().0;
530 let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
531 input_ids,
532 metadata
533 .as_ref()
534 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
535 .unwrap_or(cache as &dyn PastKvLenCache),
536 self.sliding_window,
537 xs.dtype(),
538 self.cfg.num_attn_heads,
539 )?;
540 let attention_mask = attention_mask.filter(|_| {
542 metadata
543 .as_ref()
544 .map(|(_, meta)| meta.is_first_prompt_chunk)
545 .unwrap_or(true)
546 });
547 for (i, layer) in self.layers.iter().enumerate() {
548 xs = self.mapper.map(xs, i)?;
549 xs = layer.forward(
550 &xs,
551 attention_mask
552 .as_ref()
553 .map(|m| m.to_device(xs.device()).unwrap())
554 .as_ref(),
555 seqlen_offsets,
556 &mut cache[i],
557 metadata
558 .as_ref()
559 .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
560 flash_params,
561 )?;
562 }
563 let xs = xs.to_device(&self.device)?;
564 let mut xs = xs.apply(&self.norm)?;
565 if let Some(t) = self.lm_head.quantized_act_type() {
566 xs = xs.to_dtype(t)?;
567 }
568 extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
569 }
570}
571
572impl IsqModel for Model {
573 fn get_layers(
574 &mut self,
575 ) -> (
576 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
577 &dyn DeviceMapper,
578 ) {
579 let mut tensors = Vec::new();
580 tensors.push((&mut self.lm_head, None));
581 for (i, layer) in self.layers.iter_mut().enumerate() {
582 tensors.push((&mut layer.self_attn.q_proj, Some(i)));
583 tensors.push((&mut layer.self_attn.k_proj, Some(i)));
584 tensors.push((&mut layer.self_attn.v_proj, Some(i)));
585 tensors.push((&mut layer.self_attn.o_proj, Some(i)));
586 tensors.extend(
587 layer
588 .mlp
589 .get_isq_layers()
590 .into_iter()
591 .map(|m| (m, Some(i)))
592 .collect::<Vec<_>>(),
593 );
594 }
595 (tensors, &*self.mapper)
596 }
597
598 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
599 let uvb = UnVarBuilder::new();
600
601 let uvb_m = uvb.pp("model");
602 uvb_m.pp("embed_tokens").add(&self.embed_tokens);
603 uvb_m.pp("norm").add(&self.norm);
604
605 for (layer_idx, layer) in self.layers.iter().enumerate() {
606 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
607 uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
608 uvb_l
609 .pp("post_attention_layernorm")
610 .add(&layer.post_attention_layernorm);
611 }
612
613 uvb.to_safetensors()
614 }
615
616 fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
617 let mut names = Vec::new();
619 names.push(None);
621 for i in 0..self.layers.len() {
622 names.push(Some(format!("blk.{i}.attn_q.weight")));
623 names.push(Some(format!("blk.{i}.attn_k.weight")));
624 names.push(Some(format!("blk.{i}.attn_v.weight")));
625 names.push(Some(format!("blk.{i}.attn_output.weight")));
626 names.push(Some(format!("blk.{i}.ffn_gate.weight")));
627 names.push(Some(format!("blk.{i}.ffn_up.weight")));
628 names.push(Some(format!("blk.{i}.ffn_down.weight")));
629 }
630 Ok(names)
631 }
632}
633
634impl NormalModel for Model {
635 fn forward(
636 &self,
637 input_ids: &Tensor,
638 seqlen_offsets: &[usize],
639 context_lens: Vec<(usize, usize)>,
640 _position_ids: Vec<usize>,
641 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
642 flash_params: &FlashParams,
643 ) -> Result<Tensor> {
644 self.forward(
645 input_ids,
646 seqlen_offsets,
647 context_lens,
648 metadata,
649 flash_params,
650 )
651 }
652 fn xlora_forward(
653 &self,
654 _input_ids: &Tensor,
655 _input_ids_full: &Tensor,
656 _seqlen_offsets: &[usize],
657 _seqlen_offsets_full: &[usize],
658 _no_kv_cache: bool,
659 _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
660 _context_lens: Vec<(usize, usize)>,
661 _position_ids: Vec<usize>,
662 _flash_params: &FlashParams,
663 _flash_params_full: &FlashParams,
664 ) -> Result<Tensor> {
665 unimplemented!()
666 }
667 fn cache(&self) -> &EitherCache {
668 &self.cache
669 }
670 fn cache_mut(&mut self) -> &mut EitherCache {
671 &mut self.cache
672 }
673 fn device(&self) -> &Device {
674 &self.device
675 }
676 fn is_xlora(&self) -> bool {
677 false
678 }
679 fn max_seq_len(&self) -> usize {
680 self.max_seq_len
681 }
682 fn config(&self) -> &ModelConfigMetadata {
683 &self.cfg
684 }
685}
686
687impl AnyMoeBaseModelMixin for Model {
688 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
689 let mut mlps = Vec::new();
690 for layer in &self.layers {
691 mlps.push(&*layer.mlp);
692 }
693 mlps
694 }
695 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
696 let mut mlps = Vec::new();
697 for layer in &mut self.layers {
698 mlps.push(&mut layer.mlp);
699 }
700 mlps
701 }
702 fn create_anymoe_layers(
703 &mut self,
704 additional_vbs: Vec<ShardedVarBuilder>,
705 config: AnyMoeConfig,
706 (prefix, mlp): (String, String),
707 mut layers: Vec<usize>,
708 expert_type: AnyMoeExpertType,
709 gate_vb: Option<ShardedVarBuilder>,
710 ) -> Result<()> {
711 let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
712 if layers.is_empty() {
713 layers = (0..self.layers.len()).collect::<Vec<_>>();
714 }
715 for _ in 0..layers.len() {
716 experts.push(Vec::new());
717 }
718 for vb in additional_vbs {
719 let vb = vb.pp(&prefix);
720 for (layer, row) in experts.iter_mut().enumerate() {
721 if !layers.contains(&layer) {
722 continue;
723 }
724
725 let intermediate_size = self.layers[layer].mlp.get_params()[1];
726 let hidden_size = self.layers[layer].mlp.get_params()[0];
727 match expert_type {
728 AnyMoeExpertType::FineTuned => {
729 let (dtype, device) = self.layers[layer].mlp.dtype_device();
730 row.push(Box::new(Mlp::replicate(
731 self.layers[layer].mlp.get_params(),
732 vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
733 self.layers[layer].mlp.hidden_act(),
734 &self.mapper.get_comm_for(layer)?,
735 )?));
736 }
737 AnyMoeExpertType::LoraAdapter {
738 rank,
739 alpha,
740 ref target_modules,
741 } => {
742 let vb_mlp = vb.pp(layer).pp(&mlp);
743
744 let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
745 Some(get_delta_from_lora_ab!(
746 vb_mlp,
747 rank,
748 alpha,
749 (hidden_size, intermediate_size),
750 "gate_proj"
751 ))
752 } else {
753 None
754 };
755 let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
756 Some(get_delta_from_lora_ab!(
757 vb_mlp,
758 rank,
759 alpha,
760 (hidden_size, intermediate_size),
761 "up_proj"
762 ))
763 } else {
764 None
765 };
766 let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
767 Some(get_delta_from_lora_ab!(
768 vb_mlp,
769 rank,
770 alpha,
771 (intermediate_size, hidden_size),
772 "down_proj"
773 ))
774 } else {
775 None
776 };
777
778 row.push(self.layers[layer].mlp.new_added_delta(vec![
779 gate_proj_delta,
780 up_proj_delta,
781 down_proj_delta,
782 ])?);
783 }
784 }
785 }
786 }
787 for (layer, expert) in layers.into_iter().zip(experts) {
788 let mut experts_all = vec![self.layers[layer].mlp.clone()];
789 experts_all.extend(expert);
790 let (dtype, device) = self.layers[layer].mlp.dtype_device();
791 self.layers[layer].mlp = Box::new(MoeMlp::new(
792 experts_all,
793 config.clone(),
794 dtype,
795 &device,
796 layer,
797 gate_vb.as_ref(),
798 )?);
799 }
800 Ok(())
801 }
802 fn amoe_supported(&self) -> bool {
803 true
804 }
805}