#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantMethodConfig, RowParallelLayer, ShardedVarBuilder,
    UnquantLinear,
};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{self, linear_no_bias, Activation, CausalMasker, MatMul, RmsNorm, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    AnyMoeConfig, AnyMoeExpertType,
};

use super::{LLaVALLM, OrdinaryRoPE};
use crate::models::mistral::Config;
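/// Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, with the gate/up
/// projections split column-wise and the down projection split row-wise for tensor parallelism.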
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    params: Vec<usize>,
}

impl MLP {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("gate_proj"),
        )?;
        let up_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("up_proj"),
        )?;
        let down_proj = RowParallelLayer::new(
            intermediate_sz,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("down_proj"),
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            params: vec![hidden_sz, intermediate_sz],
        })
    }
}

impl AnyMoeTrainableLayer for MLP {}

impl MlpLayer for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul
            .qmethod_matmul(&xs, &*self.gate_proj)?
            .apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.up_proj)?;
        let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.down_proj)?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_proj, &mut self.up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let gate_proj = if let Some(ref delta) = deltas[0] {
            self.gate_proj.add_delta_w(delta)?
        } else {
            self.gate_proj.clone()
        };
        let up_proj = if let Some(ref delta) = deltas[1] {
            self.up_proj.add_delta_w(delta)?
        } else {
            self.up_proj.clone()
        };
        let down_proj = if let Some(ref delta) = deltas[2] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: self.act_fn,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_proj.dtype_and_device()
    }
}
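/// Multi-head self-attention with grouped-query KV heads, rotary position embeddings,
/// an optional sliding window, and an optional paged-attention backend.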
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    sliding_window: Option<usize>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            sliding_window: cfg.sliding_window,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        rope_parameter: (&Tensor, &Tensor),
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let mut q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let mut k = k
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        q = OrdinaryRoPE::forward(&q, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
        k = OrdinaryRoPE::forward(&k, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
        let v = v
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
                    kv_cache,
                    k,
                    v,
                    attention_mask,
                    self.sliding_window,
                    false,
                )?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attn_mask.as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
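/// Pre-norm decoder block: RMSNorm -> self-attention -> residual add, then
/// RMSNorm -> MLP -> residual add. Each layer holds its own RoPE cos/sin tables.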
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    rope_parameter: (Tensor, Tensor),
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        rope_parameter: (Tensor, Tensor),
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MLP::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
            rope_parameter,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let mut xs = self.input_layernorm.forward(xs)?;
        xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            (&self.rope_parameter.0, &self.rope_parameter.1),
            metadata,
            flash_params,
        )?;
        xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}
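/// Mistral-style decoder stack used as the language-model backbone for the LLaVA pipeline.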
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let vb_lm_head = vb.pp("lm_head");
        Self::new_inner(
            cfg,
            vb_m,
            vb_lm_head,
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )
    }
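    /// Like [`Self::new`], but takes the `model` and `lm_head` var builders separately so the
    /// caller (e.g. the LLaVA wrapper) can control the weight prefixes itself.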
    pub fn new_inner(
        cfg: &Config,
        vb_m: ShardedVarBuilder,
        vb_lm_head: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb_m)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let vb_l = vb_m.pp("layers");
        let layers = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rope_parameters = OrdinaryRoPE::create_parameters(
                head_dim,
                cfg.max_position_embeddings,
                cfg.rope_theta as f32,
                vb_m.dtype(),
                device,
            )?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            DecoderLayer::new(
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                rope_parameters,
                &comm,
            )
        })?;
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
        )?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lm_head))?),
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            self.embed_tokens.forward(input_ids)?,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
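    /// Run the decoder over precomputed input embeddings (e.g. token embeddings into which the
    /// LLaVA wrapper has already merged projected image features).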
    #[allow(clippy::too_many_arguments)]
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        input_embeds: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = input_embeds;
        let mut cache = self.cache.full().lock();
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        xs = xs.to_device(&self.device)?;
        xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        Vec::new()
    }
}

impl LLaVALLM for Model {
    fn embed(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.get_input_embeddings(input_ids)
    }

    fn forward_input_embed(
        &self,
        input_ids: &Tensor,
        input_embed: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            input_embed,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(MLP::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "gate_proj"
                            ))
                        } else {
                            None
                        };
                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "up_proj"
                            ))
                        } else {
                            None
                        };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.layers[layer].mlp.new_added_delta(vec![
                            gate_proj_delta,
                            up_proj_delta,
                            down_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}