1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5use candle_core::{DType, Device, Module, Result, Tensor};
7use mistralrs_quant::{
8 ColumnParallelLayer, QuantMethod, QuantMethodConfig, RowParallelLayer, ShardedVarBuilder,
9 UnquantLinear,
10};
11
12use crate::{
13 amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
14 attention::SdpaParams,
15 device_map::DeviceMapper,
16 get_delta_from_lora_ab,
17 layers::{self, linear_no_bias, Activation, CausalMasker, MatMul, RmsNorm, Sdpa},
18 layers_masker::PastKvLenCache,
19 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
20 pipeline::{
21 extract_logits,
22 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
23 Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
24 },
25 utils::progress::NiceProgressBar,
26 AnyMoeConfig, AnyMoeExpertType,
27};
28
29use super::{LLaVALLM, OrdinaryRoPE};
30use crate::models::mistral::Config;
31
32#[derive(Clone)]
33#[allow(clippy::upper_case_acronyms)]
34struct MLP {
35 gate_proj: Arc<dyn QuantMethod>,
36 up_proj: Arc<dyn QuantMethod>,
37 down_proj: Arc<dyn QuantMethod>,
38 act_fn: Activation,
39 params: Vec<usize>,
40}
41
42impl MLP {
43 fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
44 let hidden_sz = cfg.hidden_size;
45 let intermediate_sz = cfg.intermediate_size;
46 let gate_proj = ColumnParallelLayer::new(
47 hidden_sz,
48 intermediate_sz,
49 &cfg.quantization_config,
50 false,
51 comm,
52 vb.pp("gate_proj"),
53 )?;
54 let up_proj = ColumnParallelLayer::new(
55 hidden_sz,
56 intermediate_sz,
57 &cfg.quantization_config,
58 false,
59 comm,
60 vb.pp("up_proj"),
61 )?;
62 let down_proj = RowParallelLayer::new(
63 intermediate_sz,
64 hidden_sz,
65 &cfg.quantization_config,
66 false,
67 comm,
68 vb.pp("down_proj"),
69 )?;
70 Ok(Self {
71 gate_proj,
72 up_proj,
73 down_proj,
74 act_fn: cfg.hidden_act,
75 params: vec![hidden_sz, intermediate_sz],
76 })
77 }
78}
79
80impl AnyMoeTrainableLayer for MLP {}
81
82impl MlpLayer for MLP {
83 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
84 let original_dtype = xs.dtype();
85 let mut xs = xs.clone();
86 if let Some(t) = self.gate_proj.quantized_act_type() {
87 xs = xs.to_dtype(t)?;
88 }
89 let lhs = MatMul
90 .qmethod_matmul(&xs, &*self.gate_proj)?
91 .apply(&self.act_fn)?;
92 let rhs = MatMul.qmethod_matmul(&xs, &*self.up_proj)?;
93 let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.down_proj)?;
94 if self.gate_proj.quantized_act_type().is_some() {
95 res = res.to_dtype(original_dtype)?;
96 }
97 Ok(res)
98 }
99 fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
100 vec![&mut self.gate_proj, &mut self.up_proj, &mut self.down_proj]
101 }
102 fn clone(&self) -> Box<dyn MlpLayer> {
103 Box::new(Clone::clone(self))
104 }
105 fn get_params(&self) -> &[usize] {
106 &self.params
107 }
108 fn hidden_act(&self) -> Activation {
109 self.act_fn
110 }
111 fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
113 let gate_proj = if let Some(ref delta) = deltas[0] {
114 self.gate_proj.add_delta_w(delta)?
115 } else {
116 self.gate_proj.clone()
117 };
118 let up_proj = if let Some(ref delta) = deltas[1] {
119 self.up_proj.add_delta_w(delta)?
120 } else {
121 self.up_proj.clone()
122 };
123 let down_proj = if let Some(ref delta) = deltas[2] {
124 self.down_proj.add_delta_w(delta)?
125 } else {
126 self.down_proj.clone()
127 };
128
129 Ok(Box::new(Self {
130 gate_proj,
131 up_proj,
132 down_proj,
133 act_fn: self.act_fn,
134 params: self.params.clone(),
135 }))
136 }
137
138 fn dtype_device(&self) -> (DType, Device) {
139 self.gate_proj.dtype_and_device()
140 }
141}
142
143struct Attention {
144 q_proj: Arc<dyn QuantMethod>,
145 k_proj: Arc<dyn QuantMethod>,
146 v_proj: Arc<dyn QuantMethod>,
147 o_proj: Arc<dyn QuantMethod>,
148 num_heads: usize,
149 num_kv_heads: usize,
150 head_dim: usize,
151 sliding_window: Option<usize>,
152 paged_attn: Option<PagedAttention>,
153 sdpa_params: SdpaParams,
154}
155
156impl Attention {
157 fn new(
158 cfg: &Config,
159 vb: ShardedVarBuilder,
160 paged_attn: Option<PagedAttention>,
161 comm: &Arc<mistralrs_quant::Comm>,
162 ) -> Result<Self> {
163 let hidden_sz = cfg.hidden_size;
164 let num_heads = cfg.num_attention_heads;
165 let num_kv_heads = cfg.num_key_value_heads;
166 let head_dim = cfg.head_dim();
167 let q_proj = ColumnParallelLayer::new(
168 hidden_sz,
169 num_heads * head_dim,
170 &cfg.quantization_config,
171 false,
172 comm,
173 vb.pp("q_proj"),
174 )?;
175 let kv_shard = mistralrs_quant::compute_kv_shard(
176 cfg.num_key_value_heads,
177 cfg.hidden_size / cfg.num_attention_heads,
178 comm,
179 );
180 let k_proj = ColumnParallelLayer::new_with_shard(
181 hidden_sz,
182 num_kv_heads * head_dim,
183 &cfg.quantization_config,
184 false,
185 comm,
186 kv_shard,
187 vb.pp("k_proj"),
188 )?;
189 let v_proj = ColumnParallelLayer::new_with_shard(
190 hidden_sz,
191 num_kv_heads * head_dim,
192 &cfg.quantization_config,
193 false,
194 comm,
195 kv_shard,
196 vb.pp("v_proj"),
197 )?;
198 let o_proj = RowParallelLayer::new(
199 num_heads * head_dim,
200 hidden_sz,
201 &cfg.quantization_config,
202 false,
203 comm,
204 vb.pp("o_proj"),
205 )?;
206 Ok(Self {
207 q_proj,
208 k_proj,
209 v_proj,
210 o_proj,
211 num_heads: num_heads / comm.world_size(),
212 num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
213 head_dim,
214 sliding_window: cfg.sliding_window,
215 paged_attn,
216 sdpa_params: SdpaParams {
217 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
218 cfg.num_key_value_heads,
219 cfg.num_attention_heads,
220 comm,
221 ),
222 softcap: None,
223 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
224 sliding_window: cfg.sliding_window,
225 },
226 })
227 }
228
229 #[allow(clippy::too_many_arguments)]
230 fn forward(
231 &self,
232 xs: &Tensor,
233 attention_mask: Option<&Tensor>,
234 seqlen_offsets: &[usize],
235 kv_cache: &mut Option<(Tensor, Tensor)>,
236 rope_parameter: (&Tensor, &Tensor),
237 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
238 flash_params: &FlashParams,
239 ) -> Result<Tensor> {
240 let (b_sz, q_len, _) = xs.dims3()?;
241
242 let original_dtype = xs.dtype();
243 let mut xs = xs.clone();
244 if let Some(t) = self.q_proj.quantized_act_type() {
245 xs = xs.to_dtype(t)?;
246 }
247 let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
248 let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
249 let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
250 if self.q_proj.quantized_act_type().is_some() {
251 q = q.to_dtype(original_dtype)?;
252 k = k.to_dtype(original_dtype)?;
253 v = v.to_dtype(original_dtype)?;
254 }
255
256 let mut q = q
257 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
258 .transpose(1, 2)?
259 .contiguous()?;
260 let mut k = k
261 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
262 .transpose(1, 2)?
263 .contiguous()?;
264 q = OrdinaryRoPE::forward(&q, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
265 k = OrdinaryRoPE::forward(&k, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
266 let v = v
267 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
268 .transpose(1, 2)?;
269
270 let mut attn_output = match &self.paged_attn {
271 Some(paged_attn) => match metadata {
272 Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
273 &q,
274 &k,
275 &v,
276 attention_mask,
277 Some(key_cache),
278 Some(value_cache),
279 input_metadata,
280 &self.sdpa_params,
281 Some(flash_params),
282 )?,
283 None => {
284 let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
287 assert!(attention_mask.is_some());
289 paged_attn.forward(
290 &q,
291 &k,
292 &v,
293 attention_mask,
294 None,
295 None,
296 &input_metadata,
297 &self.sdpa_params,
298 Some(flash_params),
299 )?
300 }
301 },
302 None => {
303 let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
304 kv_cache,
305 k,
306 v,
307 attention_mask,
308 self.sliding_window,
309 )?;
310
311 Sdpa.run_attention(
312 &q,
313 &k,
314 &v,
315 attn_mask.as_ref(),
316 Some(flash_params),
317 &self.sdpa_params,
318 )?
319 }
320 };
321
322 if let Some(t) = self.q_proj.quantized_act_type() {
323 attn_output = attn_output.to_dtype(t)?;
324 }
325 attn_output = if attention_mask.is_some() {
326 attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
327 } else {
328 attn_output.reshape((b_sz, q_len, ()))?
329 };
330 let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
331 if self.q_proj.quantized_act_type().is_some() {
332 res = res.to_dtype(original_dtype)?;
333 }
334 Ok(res)
335 }
336}
337
338struct DecoderLayer {
339 self_attn: Attention,
340 mlp: Box<dyn MlpLayer>,
341 input_layernorm: RmsNorm,
342 post_attention_layernorm: RmsNorm,
343 rope_parameter: (Tensor, Tensor),
344}
345
346impl DecoderLayer {
347 #[allow(clippy::too_many_arguments)]
348 fn new(
349 cfg: &Config,
350 vb: ShardedVarBuilder,
351 mapper: &dyn DeviceMapper,
352 layer_idx: usize,
353 loading_isq: bool,
354 paged_attn: Option<PagedAttention>,
355 rope_parameter: (Tensor, Tensor),
356 comm: &Arc<mistralrs_quant::Comm>,
357 ) -> Result<Self> {
358 let self_attn = Attention::new(
359 cfg,
360 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
361 paged_attn,
362 comm,
363 )?;
364 let mlp = MLP::new(
365 cfg,
366 mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
367 comm,
368 )?;
369 let input_layernorm = RmsNorm::new(
370 cfg.hidden_size,
371 cfg.rms_norm_eps,
372 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
373 )?;
374 let post_attention_layernorm = RmsNorm::new(
375 cfg.hidden_size,
376 cfg.rms_norm_eps,
377 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
378 )?;
379 Ok(Self {
380 self_attn,
381 mlp: Box::new(mlp),
382 input_layernorm,
383 post_attention_layernorm,
384 rope_parameter,
385 })
386 }
387
388 #[allow(clippy::too_many_arguments)]
389 fn forward(
390 &self,
391 xs: &Tensor,
392 attention_mask: Option<&Tensor>,
393 seqlen_offsets: &[usize],
394 kv_cache: &mut Option<(Tensor, Tensor)>,
395 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
396 flash_params: &FlashParams,
397 ) -> Result<Tensor> {
398 let residual = xs;
399 let mut xs = self.input_layernorm.forward(xs)?;
400 xs = self.self_attn.forward(
401 &xs,
402 attention_mask,
403 seqlen_offsets,
404 kv_cache,
405 (&self.rope_parameter.0, &self.rope_parameter.1),
406 metadata,
407 flash_params,
408 )?;
409 xs = (xs + residual)?;
410 let residual = &xs;
411 let xs = self
412 .mlp
413 .forward(&xs.apply(&self.post_attention_layernorm)?)?;
414 residual + xs
415 }
416}
417
418pub struct Model {
419 embed_tokens: candle_nn::Embedding,
420 layers: Vec<DecoderLayer>,
421 norm: RmsNorm,
422 lm_head: Arc<dyn QuantMethod>,
423 sliding_window: Option<usize>,
424 device: Device,
425 cache: EitherCache,
426 max_seq_len: usize,
427 mapper: Box<dyn DeviceMapper + Send + Sync>,
428 cfg: ModelConfigMetadata,
429}
430
431impl Model {
432 pub fn new(
433 cfg: &Config,
434 vb: ShardedVarBuilder,
435 is_gptx: bool,
436 normal_loading_metadata: NormalLoadingMetadata,
437 attention_mechanism: AttentionImplementation,
438 ) -> Result<Self> {
439 let vb_m = vb.pp("model");
440 let vb_lm_head = vb.pp("lm_head");
441 Self::new_inner(
442 cfg,
443 vb_m,
444 vb_lm_head,
445 is_gptx,
446 normal_loading_metadata,
447 attention_mechanism,
448 )
449 }
450
451 pub fn new_inner(
452 cfg: &Config,
453 vb_m: ShardedVarBuilder,
454 vb_lm_head: ShardedVarBuilder,
455 _is_gptx: bool,
456 normal_loading_metadata: NormalLoadingMetadata,
457 attention_mechanism: AttentionImplementation,
458 ) -> Result<Self> {
459 if let Some(ref quant_cfg) = &cfg.quantization_config {
460 tracing::info!(
461 "Using {} quantization: {}.",
462 quant_cfg.name(),
463 quant_cfg.get_bits_name(&vb_m)
464 );
465 }
466 let mapper = normal_loading_metadata.mapper;
467 let embed_tokens = layers::embedding(
468 cfg.vocab_size,
469 cfg.hidden_size,
470 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
471 &cfg.quantization_config,
472 )?;
473 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
474 let vb_l = vb_m.pp("layers");
475 let layers = NiceProgressBar::<_, 'b'>(
476 0..cfg.num_hidden_layers,
477 "Loading repeating layers",
478 &normal_loading_metadata.multi_progress,
479 )
480 .par_iter_if_isq(|layer_idx| {
481 let device = mapper
482 .device_for(layer_idx, false)
483 .unwrap_or(&normal_loading_metadata.real_device);
484 let rope_parameters = OrdinaryRoPE::create_parameters(
485 head_dim,
486 cfg.max_position_embeddings,
487 cfg.rope_theta as f32,
488 vb_m.dtype(),
489 device,
490 )?;
491 let paged_attn = match &attention_mechanism {
492 AttentionImplementation::Eager => None,
493 AttentionImplementation::PagedAttention => {
494 Some(PagedAttention::new(head_dim, device, None)?)
495 }
496 };
497 let comm = mapper.get_comm_for(layer_idx)?;
498 DecoderLayer::new(
499 cfg,
500 vb_l.pp(layer_idx),
501 &*mapper,
502 layer_idx,
503 normal_loading_metadata.loading_isq,
504 paged_attn,
505 rope_parameters,
506 &comm,
507 )
508 })?;
509 let norm = RmsNorm::new(
510 cfg.hidden_size,
511 cfg.rms_norm_eps,
512 mapper.set_nm_device(vb_m.pp("norm"), false),
513 )?;
514 let lm_head = linear_no_bias(
515 cfg.hidden_size,
516 cfg.vocab_size,
517 mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
518 )?;
519 Ok(Self {
520 embed_tokens,
521 layers,
522 norm,
523 lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lm_head))?),
524 sliding_window: cfg.sliding_window,
525 device: normal_loading_metadata.real_device,
526 cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
527 max_seq_len: cfg.max_position_embeddings,
528 cfg: ModelConfigMetadata {
529 max_seq_len: cfg.max_position_embeddings,
530 num_layers: cfg.num_hidden_layers,
531 hidden_size: cfg.hidden_size,
532 num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
533 num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
534 .max(1),
535 sliding_window: cfg.sliding_window,
536 k_head_dim: cfg.head_dim(),
537 v_head_dim: cfg.head_dim(),
538 },
539 mapper,
540 })
541 }
542
543 pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
544 self.embed_tokens.forward(input_ids)
545 }
546
547 pub fn forward(
548 &self,
549 input_ids: &Tensor,
550 seqlen_offsets: &[usize],
551 context_lens: Vec<(usize, usize)>,
552 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
553 flash_params: &FlashParams,
554 ) -> Result<Tensor> {
555 self.forward_embeds(
556 input_ids,
557 self.embed_tokens.forward(input_ids)?,
558 seqlen_offsets,
559 context_lens,
560 metadata,
561 flash_params,
562 )
563 }
564
565 #[allow(clippy::too_many_arguments)]
566 pub fn forward_embeds(
567 &self,
568 input_ids: &Tensor,
569 input_embeds: Tensor,
570 seqlen_offsets: &[usize],
571 context_lens: Vec<(usize, usize)>,
572 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
573 flash_params: &FlashParams,
574 ) -> Result<Tensor> {
575 let mut xs = input_embeds;
576 let mut cache = self.cache.full().lock();
577 let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
578 input_ids,
579 metadata
580 .as_ref()
581 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
582 .unwrap_or(&*cache as &dyn PastKvLenCache),
583 self.sliding_window,
584 xs.dtype(),
585 self.cfg.num_attn_heads,
586 )?;
587 let attention_mask = attention_mask.filter(|_| {
588 metadata
589 .as_ref()
590 .map(|(_, meta)| meta.is_first_prompt_chunk)
591 .unwrap_or(true)
592 });
593 for (i, layer) in self.layers.iter().enumerate() {
594 xs = self.mapper.map(xs, i)?;
595 xs = layer.forward(
596 &xs,
597 attention_mask
598 .as_ref()
599 .map(|m| m.to_device(xs.device()).unwrap())
600 .as_ref(),
601 seqlen_offsets,
602 &mut cache[i],
603 metadata
604 .as_ref()
605 .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
606 flash_params,
607 )?;
608 }
609 xs = xs.to_device(&self.device)?;
610 xs = xs.apply(&self.norm)?;
611 if let Some(t) = self.lm_head.quantized_act_type() {
612 xs = xs.to_dtype(t)?;
613 }
614 extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
615 }
616}
617
618impl IsqModel for Model {
619 fn get_layers(
620 &mut self,
621 ) -> (
622 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
623 &dyn DeviceMapper,
624 ) {
625 let mut tensors = Vec::new();
626 tensors.push((&mut self.lm_head, None));
627 for (i, layer) in self.layers.iter_mut().enumerate() {
628 tensors.push((&mut layer.self_attn.q_proj, Some(i)));
629 tensors.push((&mut layer.self_attn.k_proj, Some(i)));
630 tensors.push((&mut layer.self_attn.v_proj, Some(i)));
631 tensors.push((&mut layer.self_attn.o_proj, Some(i)));
632 tensors.extend(
633 layer
634 .mlp
635 .get_isq_layers()
636 .into_iter()
637 .map(|m| (m, Some(i)))
638 .collect::<Vec<_>>(),
639 );
640 }
641 (tensors, &*self.mapper)
642 }
643
644 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
645 Vec::new()
646 }
647}
648
649impl LLaVALLM for Model {
650 fn embed(&self, input_ids: &Tensor) -> Result<Tensor> {
651 self.get_input_embeddings(input_ids)
652 }
653
654 fn forward_input_embed(
655 &self,
656 input_ids: &Tensor,
657 input_embed: Tensor,
658 seqlen_offsets: &[usize],
659 context_lens: Vec<(usize, usize)>,
660 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
661 flash_params: &FlashParams,
662 ) -> Result<Tensor> {
663 self.forward_embeds(
664 input_ids,
665 input_embed,
666 seqlen_offsets,
667 context_lens,
668 metadata,
669 flash_params,
670 )
671 }
672}
673
674impl NormalModel for Model {
675 fn forward(
676 &self,
677 input_ids: &Tensor,
678 seqlen_offsets: &[usize],
679 context_lens: Vec<(usize, usize)>,
680 _position_ids: Vec<usize>,
681 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
682 flash_params: &FlashParams,
683 ) -> Result<Tensor> {
684 self.forward(
685 input_ids,
686 seqlen_offsets,
687 context_lens,
688 metadata,
689 flash_params,
690 )
691 }
692 fn xlora_forward(
693 &self,
694 _input_ids: &Tensor,
695 _input_ids_full: &Tensor,
696 _seqlen_offsets: &[usize],
697 _seqlen_offsets_full: &[usize],
698 _no_kv_cache: bool,
699 _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
700 _context_lens: Vec<(usize, usize)>,
701 _position_ids: Vec<usize>,
702 _flash_params: &FlashParams,
703 _flash_params_full: &FlashParams,
704 ) -> Result<Tensor> {
705 unimplemented!()
706 }
707 fn cache(&self) -> &EitherCache {
708 &self.cache
709 }
710 fn cache_mut(&mut self) -> &mut EitherCache {
711 &mut self.cache
712 }
713 fn device(&self) -> &Device {
714 &self.device
715 }
716 fn is_xlora(&self) -> bool {
717 false
718 }
719 fn max_seq_len(&self) -> usize {
720 self.max_seq_len
721 }
722 fn config(&self) -> &ModelConfigMetadata {
723 &self.cfg
724 }
725}
726
727impl AnyMoeBaseModelMixin for Model {
728 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
729 let mut mlps = Vec::new();
730 for layer in &self.layers {
731 mlps.push(&*layer.mlp);
732 }
733 mlps
734 }
735 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
736 let mut mlps = Vec::new();
737 for layer in &mut self.layers {
738 mlps.push(&mut layer.mlp);
739 }
740 mlps
741 }
742 fn create_anymoe_layers(
743 &mut self,
744 additional_vbs: Vec<ShardedVarBuilder>,
745 config: AnyMoeConfig,
746 (prefix, mlp): (String, String),
747 mut layers: Vec<usize>,
748 expert_type: AnyMoeExpertType,
749 gate_vb: Option<ShardedVarBuilder>,
750 ) -> Result<()> {
751 let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
752 if layers.is_empty() {
753 layers = (0..self.layers.len()).collect::<Vec<_>>();
754 }
755 for _ in 0..layers.len() {
756 experts.push(Vec::new());
757 }
758 for vb in additional_vbs {
759 let vb = vb.pp(&prefix);
760 for (layer, row) in experts.iter_mut().enumerate() {
761 if !layers.contains(&layer) {
762 continue;
763 }
764
765 let intermediate_size = self.layers[layer].mlp.get_params()[1];
766 let hidden_size = self.layers[layer].mlp.get_params()[0];
767 match expert_type {
768 AnyMoeExpertType::FineTuned => {
769 let (dtype, device) = self.layers[layer].mlp.dtype_device();
770 row.push(Box::new(MLP::new(
771 &Config {
772 intermediate_size: self.layers[layer].mlp.get_params()[1],
773 hidden_size: self.layers[layer].mlp.get_params()[0],
774 ..Default::default()
775 },
776 vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
777 &self.mapper.get_comm_for(layer)?,
778 )?));
779 }
780 AnyMoeExpertType::LoraAdapter {
781 rank,
782 alpha,
783 ref target_modules,
784 } => {
785 let vb_mlp = vb.pp(layer).pp(&mlp);
786
787 let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
788 Some(get_delta_from_lora_ab!(
789 vb_mlp,
790 rank,
791 alpha,
792 (hidden_size, intermediate_size),
793 "gate_proj"
794 ))
795 } else {
796 None
797 };
798 let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
799 Some(get_delta_from_lora_ab!(
800 vb_mlp,
801 rank,
802 alpha,
803 (hidden_size, intermediate_size),
804 "up_proj"
805 ))
806 } else {
807 None
808 };
809 let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
810 Some(get_delta_from_lora_ab!(
811 vb_mlp,
812 rank,
813 alpha,
814 (intermediate_size, hidden_size),
815 "down_proj"
816 ))
817 } else {
818 None
819 };
820
821 row.push(self.layers[layer].mlp.new_added_delta(vec![
822 gate_proj_delta,
823 up_proj_delta,
824 down_proj_delta,
825 ])?);
826 }
827 }
828 }
829 }
830 for (layer, expert) in layers.into_iter().zip(experts) {
831 let mut experts_all = vec![self.layers[layer].mlp.clone()];
832 experts_all.extend(expert);
833 let (dtype, device) = self.layers[layer].mlp.dtype_device();
834 self.layers[layer].mlp = Box::new(MoeMlp::new(
835 experts_all,
836 config.clone(),
837 dtype,
838 &device,
839 layer,
840 gate_vb.as_ref(),
841 )?);
842 }
843 Ok(())
844 }
845 fn amoe_supported(&self) -> bool {
846 true
847 }
848}