1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{Device, Module, Result, Tensor};
5use mistralrs_quant::{
6 ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
7 ShardedVarBuilder,
8};
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, sync::Arc};
11
12use crate::{
13 amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
14 attention::SdpaParams,
15 device_map::DeviceMapper,
16 get_delta_from_lora_ab,
17 layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
18 layers_masker::PastKvLenCache,
19 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
20 pipeline::{
21 extract_logits,
22 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
23 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
24 },
25 serde_default_fn,
26 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
27};
28
// Default providers named by the `#[serde(default = "...")]` attributes on `Config`.
// NOTE(review): `serde_default_fn!` is defined elsewhere in the crate; presumably each
// invocation expands to a function of the given name returning the given value (`false`).
serde_default_fn!(bool, use_flash_attn, false);
serde_default_fn!(bool, tie_word_embeddings, false);
31
/// Hyperparameters of the model, (de)serializable through serde.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Config {
    /// Row count of the token embedding table and output width of the LM head.
    pub(crate) vocab_size: usize,
    /// Width of the hidden states; input width of the attention projections and the MLP.
    pub(crate) hidden_size: usize,
    /// Inner width handed to each decoder layer's MLP.
    pub(crate) intermediate_size: usize,
    /// Number of decoder layers that `Model::new_inner` builds.
    pub(crate) num_hidden_layers: usize,
    /// Total number of query heads, before division by the communicator world size.
    pub(crate) num_attention_heads: usize,
    /// Total number of key/value heads, before division by the communicator world size.
    pub(crate) num_key_value_heads: usize,
    /// Activation handed to each MLP.
    pub(crate) hidden_act: Activation,
    /// Used as the RoPE table length, the KV-cache size argument and the reported
    /// maximum sequence length.
    pub(crate) max_position_embeddings: usize,
    /// Epsilon passed to every `RmsNorm`.
    pub(crate) rms_norm_eps: f64,
    /// RoPE base; narrowed to `f32` when the rotary embedding is constructed.
    pub(crate) rope_theta: f64,
    /// Optional sliding-window size forwarded to the attention parameters, the KV cache
    /// and the causal-mask builder.
    pub(crate) sliding_window: Option<usize>,
    /// Forwarded to `SdpaParams::use_flash_attn`. Falls back to the `use_flash_attn`
    /// serde default (declared with `false`) when missing from the serialized config.
    #[serde(default = "use_flash_attn")]
    pub(crate) use_flash_attn: bool,
    /// Explicit per-head dimension. When `None`, `Config::head_dim()` derives it as
    /// `hidden_size / num_attention_heads`.
    pub(crate) head_dim: Option<usize>,
    /// Optional quantization settings passed to the attention and MLP linear layers.
    pub(crate) quantization_config: Option<QuantizedConfig>,
    /// When `true`, the LM head is built from the token-embedding weights instead of
    /// being loaded separately. Falls back to the `tie_word_embeddings` serde default
    /// (declared with `false`) when missing.
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
}
52
53impl Config {
54 pub(crate) fn head_dim(&self) -> usize {
55 self.head_dim
56 .unwrap_or(self.hidden_size / self.num_attention_heads)
57 }
58}
59
/// Self-attention block of a single decoder layer.
struct Attention {
    /// Query projection: `hidden_size -> num_attention_heads * head_dim` (column-parallel).
    q_proj: Arc<dyn QuantMethod>,
    /// Key projection: `hidden_size -> num_key_value_heads * head_dim` (column-parallel,
    /// built with the KV shard).
    k_proj: Arc<dyn QuantMethod>,
    /// Value projection: same shape and sharding as `k_proj`.
    v_proj: Arc<dyn QuantMethod>,
    /// Output projection: `num_attention_heads * head_dim -> hidden_size` (row-parallel).
    o_proj: Arc<dyn QuantMethod>,
    /// Query heads used for reshaping: `num_attention_heads / comm.world_size()`.
    num_heads: usize,
    /// KV heads used for reshaping: `num_key_value_heads / comm.world_size()`, at least 1.
    num_kv_heads: usize,
    /// Per-head dimension, from `Config::head_dim()`.
    head_dim: usize,
    /// Rotary position embedding applied to the query and key tensors.
    rotary_emb: Arc<RotaryEmbedding>,
    /// Paged-attention backend; when `None`, `forward` appends to the KV cache and
    /// runs `Sdpa` instead.
    paged_attn: Option<PagedAttention>,
    /// Parameters passed to whichever attention backend is used.
    sdpa_params: SdpaParams,
}
72
73impl Attention {
74 fn new(
75 rotary_emb: Arc<RotaryEmbedding>,
76 cfg: &Config,
77 vb: ShardedVarBuilder,
78 paged_attn: Option<PagedAttention>,
79 comm: &Arc<mistralrs_quant::Comm>,
80 ) -> Result<Self> {
81 let hidden_sz = cfg.hidden_size;
82 let num_heads = cfg.num_attention_heads;
83 let num_kv_heads = cfg.num_key_value_heads;
84 let head_dim = cfg.head_dim();
85 let q_proj = ColumnParallelLayer::new(
86 hidden_sz,
87 num_heads * head_dim,
88 &cfg.quantization_config,
89 false,
90 comm,
91 vb.pp("q_proj"),
92 )?;
93 let kv_shard = mistralrs_quant::compute_kv_shard(
94 cfg.num_key_value_heads,
95 cfg.hidden_size / cfg.num_attention_heads,
96 comm,
97 );
98 let k_proj = ColumnParallelLayer::new_with_shard(
99 hidden_sz,
100 num_kv_heads * head_dim,
101 &cfg.quantization_config,
102 false,
103 comm,
104 kv_shard,
105 vb.pp("k_proj"),
106 )?;
107 let v_proj = ColumnParallelLayer::new_with_shard(
108 hidden_sz,
109 num_kv_heads * head_dim,
110 &cfg.quantization_config,
111 false,
112 comm,
113 kv_shard,
114 vb.pp("v_proj"),
115 )?;
116 let o_proj = RowParallelLayer::new(
117 num_heads * head_dim,
118 hidden_sz,
119 &cfg.quantization_config,
120 false,
121 comm,
122 vb.pp("o_proj"),
123 )?;
124 Ok(Self {
125 q_proj,
126 k_proj,
127 v_proj,
128 o_proj,
129 num_heads: num_heads / comm.world_size(),
130 num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
131 head_dim,
132 rotary_emb,
133 paged_attn,
134 sdpa_params: SdpaParams {
135 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
136 cfg.num_key_value_heads,
137 cfg.num_attention_heads,
138 comm,
139 ),
140 use_flash_attn: cfg.use_flash_attn,
141 softcap: None,
142 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
143 sliding_window: cfg.sliding_window,
144 },
145 })
146 }
147
148 #[allow(clippy::too_many_arguments)]
149 fn forward(
150 &self,
151 xs: &Tensor,
152 attention_mask: Option<&Tensor>,
153 seqlen_offsets: &[usize],
154 kv_cache: &mut KvCache,
155 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
156 flash_params: &FlashParams,
157 ) -> Result<Tensor> {
158 let (b_sz, q_len, _) = xs.dims3()?;
159
160 let original_dtype = xs.dtype();
161 let mut xs = xs.clone();
162 if let Some(t) = self.q_proj.quantized_act_type() {
163 xs = xs.to_dtype(t)?;
164 }
165 let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
166 let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
167 let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
168 if self.q_proj.quantized_act_type().is_some() {
169 q = q.to_dtype(original_dtype)?;
170 k = k.to_dtype(original_dtype)?;
171 v = v.to_dtype(original_dtype)?;
172 }
173
174 let (q, k, v) = if q_len != 1 {
175 let q = q
176 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
177 .transpose(1, 2)?;
178 let k = k
179 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
180 .transpose(1, 2)?;
181 let v = v
182 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
183 .transpose(1, 2)?;
184 (q, k, v)
185 } else {
186 let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
187 let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
188 let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
189 (q, k, v)
190 };
191
192 let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
193
194 let mut attn_output = match &self.paged_attn {
195 Some(paged_attn) => match metadata {
196 Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
197 &q,
198 &k,
199 &v,
200 attention_mask,
201 Some(key_cache),
202 Some(value_cache),
203 input_metadata,
204 &self.sdpa_params,
205 Some(flash_params),
206 )?,
207 None => {
208 let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
211 assert!(attention_mask.is_some());
213 paged_attn.forward(
214 &q,
215 &k,
216 &v,
217 attention_mask,
218 None,
219 None,
220 &input_metadata,
221 &self.sdpa_params,
222 Some(flash_params),
223 )?
224 }
225 },
226 None => {
227 let (k, v) = kv_cache.append(&k, &v)?;
228
229 Sdpa.run_attention(
230 &q,
231 &k,
232 &v,
233 attention_mask,
234 Some(flash_params),
235 &self.sdpa_params,
236 )?
237 }
238 };
239
240 if let Some(t) = self.q_proj.quantized_act_type() {
241 attn_output = attn_output.to_dtype(t)?;
242 }
243 attn_output = if attention_mask.is_some() {
244 attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
245 } else {
246 attn_output.reshape((b_sz, q_len, ()))?
247 };
248 let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
249 if self.q_proj.quantized_act_type().is_some() {
250 res = res.to_dtype(original_dtype)?;
251 }
252 Ok(res)
253 }
254}
255
/// One decoder layer: self-attention and an MLP, each preceded by an `RmsNorm` and
/// wrapped in a residual connection (see `DecoderLayer::forward`).
struct DecoderLayer {
    /// Self-attention sub-block.
    self_attn: Attention,
    /// Feed-forward sub-block; boxed as a trait object so it can be replaced
    /// (`create_anymoe_layers` swaps in a `MoeMlp`).
    mlp: Box<dyn MlpLayer>,
    /// Norm applied to the layer input before attention.
    input_layernorm: RmsNorm,
    /// Norm applied to the attention residual sum before the MLP.
    post_attention_layernorm: RmsNorm,
}
262
263impl DecoderLayer {
264 #[allow(clippy::too_many_arguments)]
265 fn new(
266 rotary_emb: Arc<RotaryEmbedding>,
267 cfg: &Config,
268 vb: ShardedVarBuilder,
269 mapper: &dyn DeviceMapper,
270 layer_idx: usize,
271 loading_isq: bool,
272 paged_attn: Option<PagedAttention>,
273 comm: &Arc<mistralrs_quant::Comm>,
274 ) -> Result<Self> {
275 let self_attn = Attention::new(
276 rotary_emb,
277 cfg,
278 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
279 paged_attn,
280 comm,
281 )?;
282 let mlp = Mlp::new(
283 mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
284 cfg.hidden_size,
285 cfg.intermediate_size,
286 &cfg.quantization_config,
287 cfg.hidden_act,
288 comm,
289 )?;
290 let input_layernorm = RmsNorm::new(
291 cfg.hidden_size,
292 cfg.rms_norm_eps,
293 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
294 )?;
295 let post_attention_layernorm = RmsNorm::new(
296 cfg.hidden_size,
297 cfg.rms_norm_eps,
298 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
299 )?;
300 Ok(Self {
301 self_attn,
302 mlp: Box::new(mlp),
303 input_layernorm,
304 post_attention_layernorm,
305 })
306 }
307
308 #[allow(clippy::too_many_arguments)]
309 fn forward(
310 &self,
311 xs: &Tensor,
312 attention_mask: Option<&Tensor>,
313 seqlen_offsets: &[usize],
314 kv_cache: &mut KvCache,
315 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
316 flash_params: &FlashParams,
317 ) -> Result<Tensor> {
318 let residual = xs;
319 let xs = self.input_layernorm.forward(xs)?;
320 let xs = self.self_attn.forward(
321 &xs,
322 attention_mask,
323 seqlen_offsets,
324 kv_cache,
325 metadata,
326 flash_params,
327 )?;
328 let xs = (xs + residual)?;
329 let residual = &xs;
330 let xs = self
331 .mlp
332 .forward(&xs.apply(&self.post_attention_layernorm)?)?;
333 residual + xs
334 }
335}
336
/// The full decoder-only model: token embedding, decoder layers, final norm and LM head.
pub struct Model {
    /// Token embedding table.
    embed_tokens: candle_nn::Embedding,
    /// Decoder layers, in execution order.
    layers: Vec<DecoderLayer>,
    /// Final norm applied after the last decoder layer.
    norm: RmsNorm,
    /// Output projection to the vocabulary; either loaded on its own or built from the
    /// embedding weights when `tie_word_embeddings` is set.
    lm_head: Arc<dyn QuantMethod>,
    /// Sliding-window size from the config, used when building the attention mask.
    sliding_window: Option<usize>,
    /// Device the hidden states are moved to before the final norm and LM head.
    device: Device,
    /// KV cache (constructed as the `Normal` variant).
    pub(crate) cache: EitherCache,
    /// `max_position_embeddings` from the config.
    max_seq_len: usize,
    /// Maps layers and tensors to devices.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Metadata exposed through `NormalModel::config`.
    cfg: ModelConfigMetadata,
}
349
350impl Model {
351 pub fn new(
352 cfg: &Config,
353 vb: ShardedVarBuilder,
354 is_gptx: bool,
355 normal_loading_metadata: NormalLoadingMetadata,
356 attention_mechanism: AttentionImplementation,
357 ) -> Result<Self> {
358 let vb_m = vb.pp("model");
359 let vb_lm_head = vb.pp("lm_head");
360 Self::new_inner(
361 cfg,
362 vb_m,
363 vb_lm_head,
364 is_gptx,
365 normal_loading_metadata,
366 attention_mechanism,
367 )
368 }
369
370 pub fn new_inner(
371 cfg: &Config,
372 vb_m: ShardedVarBuilder,
373 vb_lm_head: ShardedVarBuilder,
374 is_gptx: bool,
375 normal_loading_metadata: NormalLoadingMetadata,
376 attention_mechanism: AttentionImplementation,
377 ) -> Result<Self> {
378 if let Some(ref quant_cfg) = &cfg.quantization_config {
379 tracing::info!(
380 "Using {} quantization: {}.",
381 quant_cfg.quant_method.to_string(),
382 quant_cfg.get_bits_name(&vb_m)
383 );
384 }
385 let mapper = normal_loading_metadata.mapper;
386
387 let embed_tokens = embedding(
388 cfg.vocab_size,
389 cfg.hidden_size,
390 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
391 )?;
392
393 let head_dim = cfg.head_dim();
394 let mut ropes = HashMap::new();
395 for layer_idx in 0..cfg.num_hidden_layers {
396 let device = mapper
397 .device_for(layer_idx, false)
398 .unwrap_or(&normal_loading_metadata.real_device);
399 ropes.insert(
400 device.location(),
401 Arc::new(RotaryEmbedding::new(
402 cfg.rope_theta as f32,
403 head_dim,
404 cfg.max_position_embeddings,
405 device,
406 is_gptx,
407 vb_m.dtype(),
408 )?),
409 );
410 }
411
412 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
413 let vb_l = vb_m.pp("layers");
414 for layer_idx in NiceProgressBar::<_, 'b'>(
415 0..cfg.num_hidden_layers,
416 "Loading repeating layers",
417 &normal_loading_metadata.multi_progress,
418 ) {
419 let device = mapper
420 .device_for(layer_idx, false)
421 .unwrap_or(&normal_loading_metadata.real_device);
422 let rotary_emb = ropes
423 .get(&device.location())
424 .expect("No RoPE for device location!")
425 .clone();
426 let paged_attn = match &attention_mechanism {
427 AttentionImplementation::Eager => None,
428 AttentionImplementation::PagedAttention => {
429 Some(PagedAttention::new(head_dim, device, None)?)
430 }
431 };
432 let comm = mapper.get_comm_for(layer_idx)?;
433 let layer = DecoderLayer::new(
434 rotary_emb.clone(),
435 cfg,
436 vb_l.pp(layer_idx),
437 &*mapper,
438 layer_idx,
439 normal_loading_metadata.loading_isq,
440 paged_attn,
441 &comm,
442 )?;
443 layers.push(layer)
444 }
445 let norm = RmsNorm::new(
446 cfg.hidden_size,
447 cfg.rms_norm_eps,
448 mapper.set_nm_device(vb_m.pp("norm"), false),
449 )?;
450 let lm_head = if !cfg.tie_word_embeddings {
451 ReplicatedLayer::new(
452 cfg.hidden_size,
453 cfg.vocab_size,
454 &None,
455 false,
456 mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
457 )?
458 } else {
459 ReplicatedLayer::from_linear(candle_nn::Linear::new(
460 mapper.cast_nm_device(
461 embed_tokens.embeddings(),
462 normal_loading_metadata.loading_isq,
463 )?,
464 None,
465 ))?
466 };
467 Ok(Self {
468 embed_tokens,
469 layers,
470 norm,
471 lm_head,
472 sliding_window: cfg.sliding_window,
473 device: normal_loading_metadata.real_device,
474 cache: EitherCache::Normal(NormalCache::new_sliding(
475 cfg.num_hidden_layers,
476 cfg.max_position_embeddings,
477 cfg.sliding_window,
478 )),
479 max_seq_len: cfg.max_position_embeddings,
480 cfg: ModelConfigMetadata {
481 max_seq_len: cfg.max_position_embeddings,
482 num_layers: cfg.num_hidden_layers,
483 hidden_size: cfg.hidden_size,
484 num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
485 .max(1),
486 num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
487 sliding_window: cfg.sliding_window,
488 k_head_dim: cfg.head_dim(),
489 v_head_dim: cfg.head_dim(),
490 },
491 mapper,
492 })
493 }
494
495 pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
496 self.embed_tokens.forward(input_ids)
497 }
498
499 pub fn forward(
500 &self,
501 input_ids: &Tensor,
502 seqlen_offsets: &[usize],
503 context_lens: Vec<(usize, usize)>,
504 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
505 flash_params: &FlashParams,
506 ) -> Result<Tensor> {
507 self.forward_embeds(
508 input_ids,
509 self.embed_tokens.forward(input_ids)?,
510 seqlen_offsets,
511 context_lens,
512 metadata,
513 flash_params,
514 )
515 }
516
517 #[allow(clippy::too_many_arguments)]
518 pub fn forward_embeds(
519 &self,
520 input_ids: &Tensor,
521 input_embeds: Tensor,
522 seqlen_offsets: &[usize],
523 context_lens: Vec<(usize, usize)>,
524 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
525 flash_params: &FlashParams,
526 ) -> Result<Tensor> {
527 let mut xs = input_embeds;
528 let cache = &mut self.cache.normal().0;
529 let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
530 input_ids,
531 metadata
532 .as_ref()
533 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
534 .unwrap_or(cache as &dyn PastKvLenCache),
535 self.sliding_window,
536 xs.dtype(),
537 self.cfg.num_attn_heads,
538 )?;
539 let attention_mask = attention_mask.filter(|_| {
541 metadata
542 .as_ref()
543 .map(|(_, meta)| meta.is_first_prompt_chunk)
544 .unwrap_or(true)
545 });
546 for (i, layer) in self.layers.iter().enumerate() {
547 xs = self.mapper.map(xs, i)?;
548 xs = layer.forward(
549 &xs,
550 attention_mask
551 .as_ref()
552 .map(|m| m.to_device(xs.device()).unwrap())
553 .as_ref(),
554 seqlen_offsets,
555 &mut cache[i],
556 metadata
557 .as_ref()
558 .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
559 flash_params,
560 )?;
561 }
562 let xs = xs.to_device(&self.device)?;
563 let mut xs = xs.apply(&self.norm)?;
564 if let Some(t) = self.lm_head.quantized_act_type() {
565 xs = xs.to_dtype(t)?;
566 }
567 extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
568 }
569}
570
571impl IsqModel for Model {
572 fn get_layers(
573 &mut self,
574 ) -> (
575 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
576 &dyn DeviceMapper,
577 ) {
578 let mut tensors = Vec::new();
579 tensors.push((&mut self.lm_head, None));
580 for (i, layer) in self.layers.iter_mut().enumerate() {
581 tensors.push((&mut layer.self_attn.q_proj, Some(i)));
582 tensors.push((&mut layer.self_attn.k_proj, Some(i)));
583 tensors.push((&mut layer.self_attn.v_proj, Some(i)));
584 tensors.push((&mut layer.self_attn.o_proj, Some(i)));
585 tensors.extend(
586 layer
587 .mlp
588 .get_isq_layers()
589 .into_iter()
590 .map(|m| (m, Some(i)))
591 .collect::<Vec<_>>(),
592 );
593 }
594 (tensors, &*self.mapper)
595 }
596
597 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
598 let uvb = UnVarBuilder::new();
599
600 let uvb_m = uvb.pp("model");
601 uvb_m.pp("embed_tokens").add(&self.embed_tokens);
602 uvb_m.pp("norm").add(&self.norm);
603
604 for (layer_idx, layer) in self.layers.iter().enumerate() {
605 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
606 uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
607 uvb_l
608 .pp("post_attention_layernorm")
609 .add(&layer.post_attention_layernorm);
610 }
611
612 uvb.to_safetensors()
613 }
614
615 fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
616 let mut names = Vec::new();
618 names.push(None);
620 for i in 0..self.layers.len() {
621 names.push(Some(format!("blk.{i}.attn_q.weight")));
622 names.push(Some(format!("blk.{i}.attn_k.weight")));
623 names.push(Some(format!("blk.{i}.attn_v.weight")));
624 names.push(Some(format!("blk.{i}.attn_output.weight")));
625 names.push(Some(format!("blk.{i}.ffn_gate.weight")));
626 names.push(Some(format!("blk.{i}.ffn_up.weight")));
627 names.push(Some(format!("blk.{i}.ffn_down.weight")));
628 }
629 Ok(names)
630 }
631}
632
633impl NormalModel for Model {
634 fn forward(
635 &self,
636 input_ids: &Tensor,
637 seqlen_offsets: &[usize],
638 context_lens: Vec<(usize, usize)>,
639 _position_ids: Vec<usize>,
640 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
641 flash_params: &FlashParams,
642 ) -> Result<Tensor> {
643 self.forward(
644 input_ids,
645 seqlen_offsets,
646 context_lens,
647 metadata,
648 flash_params,
649 )
650 }
651 fn xlora_forward(
652 &self,
653 _input_ids: &Tensor,
654 _input_ids_full: &Tensor,
655 _seqlen_offsets: &[usize],
656 _seqlen_offsets_full: &[usize],
657 _no_kv_cache: bool,
658 _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
659 _context_lens: Vec<(usize, usize)>,
660 _position_ids: Vec<usize>,
661 _flash_params: &FlashParams,
662 _flash_params_full: &FlashParams,
663 ) -> Result<Tensor> {
664 unimplemented!()
665 }
666 fn cache(&self) -> &EitherCache {
667 &self.cache
668 }
669 fn cache_mut(&mut self) -> &mut EitherCache {
670 &mut self.cache
671 }
672 fn device(&self) -> &Device {
673 &self.device
674 }
675 fn is_xlora(&self) -> bool {
676 false
677 }
678 fn max_seq_len(&self) -> usize {
679 self.max_seq_len
680 }
681 fn config(&self) -> &ModelConfigMetadata {
682 &self.cfg
683 }
684}
685
686impl AnyMoeBaseModelMixin for Model {
687 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
688 let mut mlps = Vec::new();
689 for layer in &self.layers {
690 mlps.push(&*layer.mlp);
691 }
692 mlps
693 }
694 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
695 let mut mlps = Vec::new();
696 for layer in &mut self.layers {
697 mlps.push(&mut layer.mlp);
698 }
699 mlps
700 }
701 fn create_anymoe_layers(
702 &mut self,
703 additional_vbs: Vec<ShardedVarBuilder>,
704 config: AnyMoeConfig,
705 (prefix, mlp): (String, String),
706 mut layers: Vec<usize>,
707 expert_type: AnyMoeExpertType,
708 gate_vb: Option<ShardedVarBuilder>,
709 ) -> Result<()> {
710 let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
711 if layers.is_empty() {
712 layers = (0..self.layers.len()).collect::<Vec<_>>();
713 }
714 for _ in 0..layers.len() {
715 experts.push(Vec::new());
716 }
717 for vb in additional_vbs {
718 let vb = vb.pp(&prefix);
719 for (layer, row) in experts.iter_mut().enumerate() {
720 if !layers.contains(&layer) {
721 continue;
722 }
723
724 let intermediate_size = self.layers[layer].mlp.get_params()[1];
725 let hidden_size = self.layers[layer].mlp.get_params()[0];
726 match expert_type {
727 AnyMoeExpertType::FineTuned => {
728 let (dtype, device) = self.layers[layer].mlp.dtype_device();
729 row.push(Box::new(Mlp::replicate(
730 self.layers[layer].mlp.get_params(),
731 vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
732 self.layers[layer].mlp.hidden_act(),
733 &self.mapper.get_comm_for(layer)?,
734 )?));
735 }
736 AnyMoeExpertType::LoraAdapter {
737 rank,
738 alpha,
739 ref target_modules,
740 } => {
741 let vb_mlp = vb.pp(layer).pp(&mlp);
742
743 let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
744 Some(get_delta_from_lora_ab!(
745 vb_mlp,
746 rank,
747 alpha,
748 (hidden_size, intermediate_size),
749 "gate_proj"
750 ))
751 } else {
752 None
753 };
754 let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
755 Some(get_delta_from_lora_ab!(
756 vb_mlp,
757 rank,
758 alpha,
759 (hidden_size, intermediate_size),
760 "up_proj"
761 ))
762 } else {
763 None
764 };
765 let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
766 Some(get_delta_from_lora_ab!(
767 vb_mlp,
768 rank,
769 alpha,
770 (intermediate_size, hidden_size),
771 "down_proj"
772 ))
773 } else {
774 None
775 };
776
777 row.push(self.layers[layer].mlp.new_added_delta(vec![
778 gate_proj_delta,
779 up_proj_delta,
780 down_proj_delta,
781 ])?);
782 }
783 }
784 }
785 }
786 for (layer, expert) in layers.into_iter().zip(experts) {
787 let mut experts_all = vec![self.layers[layer].mlp.clone()];
788 experts_all.extend(expert);
789 let (dtype, device) = self.layers[layer].mlp.dtype_device();
790 self.layers[layer].mlp = Box::new(MoeMlp::new(
791 experts_all,
792 config.clone(),
793 dtype,
794 &device,
795 layer,
796 gate_vb.as_ref(),
797 )?);
798 }
799 Ok(())
800 }
    /// Always `true`: this model implements `create_anymoe_layers`.
    fn amoe_supported(&self) -> bool {
        true
    }
804}