#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantMethodConfig, RowParallelLayer, ShardedVarBuilder,
    UnquantLinear,
};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{self, linear_no_bias, Activation, CausalMasker, MatMul, RmsNorm, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    AnyMoeConfig, AnyMoeExpertType,
};

use super::{LLaVALLM, OrdinaryRoPE};
use crate::models::mistral::Config;

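/// Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
/// The gate/up projections are column-parallel and the down projection is
/// row-parallel, so the MLP shards across the tensor-parallel `Comm` group.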
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    params: Vec<usize>,
}

impl MLP {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("gate_proj"),
        )?;
        let up_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("up_proj"),
        )?;
        let down_proj = RowParallelLayer::new(
            intermediate_sz,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("down_proj"),
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            params: vec![hidden_sz, intermediate_sz],
        })
    }
}

impl AnyMoeTrainableLayer for MLP {}

impl MlpLayer for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul
            .qmethod_matmul(&xs, &*self.gate_proj)?
            .apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.up_proj)?;
        let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.down_proj)?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_proj, &mut self.up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let gate_proj = if let Some(ref delta) = deltas[0] {
            self.gate_proj.add_delta_w(delta)?
        } else {
            self.gate_proj.clone()
        };
        let up_proj = if let Some(ref delta) = deltas[1] {
            self.up_proj.add_delta_w(delta)?
        } else {
            self.up_proj.clone()
        };
        let down_proj = if let Some(ref delta) = deltas[2] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: self.act_fn,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_proj.dtype_and_device()
    }
}

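/// Multi-head self-attention with RoPE applied to queries and keys via
/// `OrdinaryRoPE`. Supports grouped-query attention (`num_kv_heads` may be
/// smaller than `num_heads`), an optional sliding-window KV cache, and an
/// optional PagedAttention backend; head counts are divided across the
/// tensor-parallel `Comm` group.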
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    sliding_window: Option<usize>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            sliding_window: cfg.sliding_window,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

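    /// Applies the QKV projections and RoPE, then attends through
    /// PagedAttention when it is available (falling back to dummy cache
    /// metadata if none was supplied) or through `Sdpa` over the
    /// sliding-window KV cache otherwise. `xs` is `(batch, seq_len,
    /// hidden_size)` and the output has the same shape after `o_proj`.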
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        rope_parameter: (&Tensor, &Tensor),
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let mut q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let mut k = k
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        q = OrdinaryRoPE::forward(&q, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
        k = OrdinaryRoPE::forward(&k, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
        let v = v
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
                    kv_cache,
                    k,
                    v,
                    attention_mask,
                    self.sliding_window,
                    false,
                )?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attn_mask.as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

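/// Pre-norm decoder block: `x + self_attn(input_layernorm(x))` followed by
/// `x + mlp(post_attention_layernorm(x))`. `rope_parameter` holds this
/// layer's precomputed RoPE tables, created on the layer's mapped device.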
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    rope_parameter: (Tensor, Tensor),
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        rope_parameter: (Tensor, Tensor),
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MLP::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
            rope_parameter,
        })
    }

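    /// Runs one decoder block, wrapping the attention and MLP sub-layers in
    /// residual connections.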
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let mut xs = self.input_layernorm.forward(xs)?;
        xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            (&self.rope_parameter.0, &self.rope_parameter.1),
            metadata,
            flash_params,
        )?;
        xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

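/// Mistral-architecture decoder-only language model used as the LLM backbone
/// for LLaVA: token embedding, a stack of [`DecoderLayer`]s, a final `RmsNorm`,
/// and an `lm_head` projection (wrapped as `UnquantLinear` at load time).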
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let vb_lm_head = vb.pp("lm_head");
        Self::new_inner(
            cfg,
            vb_m,
            vb_lm_head,
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )
    }

    pub fn new_inner(
        cfg: &Config,
        vb_m: ShardedVarBuilder,
        vb_lm_head: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb_m)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rope_parameters = OrdinaryRoPE::create_parameters(
                head_dim,
                cfg.max_position_embeddings,
                cfg.rope_theta as f32,
                vb_m.dtype(),
                device,
            )?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                rope_parameters,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
        )?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lm_head))?),
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

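    /// Token-id entry point: embeds `input_ids` and delegates to
    /// [`Self::forward_embeds`].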
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            self.embed_tokens.forward(input_ids)?,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }

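    /// Core forward pass over caller-supplied embeddings (used by
    /// `forward_input_embed`, e.g. once image features have been merged into
    /// the token embeddings); `input_ids` is still needed here to build the
    /// causal/sliding-window attention mask.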
    #[allow(clippy::too_many_arguments)]
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        input_embeds: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = input_embeds;
        let mut cache = self.cache.full().lock();
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        xs = xs.to_device(&self.device)?;
        xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        Vec::new()
    }
}

impl LLaVALLM for Model {
    fn embed(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.get_input_embeddings(input_ids)
    }

    fn forward_input_embed(
        &self,
        input_ids: &Tensor,
        input_embed: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            input_embed,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
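    // Swaps the plain MLP in each selected layer for a `MoeMlp`. Experts are
    // either fine-tuned MLPs loaded from `additional_vbs` or the base MLP with
    // LoRA deltas applied, depending on `expert_type`; the original MLP always
    // remains as the first expert.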
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(MLP::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "gate_proj"
                            ))
                        } else {
                            None
                        };
                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "up_proj"
                            ))
                        } else {
                            None
                        };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.layers[layer].mlp.new_added_delta(vec![
                            gate_proj_delta,
                            up_proj_delta,
                            down_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}