#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{
        AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainableLayer, MlpLayer,
        MoeMlp,
    },
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

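// Generates a function returning `false`, used below as the serde default for
// `tie_word_embeddings` when the field is absent from the config JSON.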
serde_default_fn!(bool, word_emb_default, false);

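/// Model configuration, deserialized from the model's `config.json`.
/// Field names mirror the upstream Hugging Face configuration keys.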
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub use_flash_attn: bool,
    pub sliding_window: Option<usize>,
    pub original_max_position_embeddings: usize,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
    pub partial_rotary_factor: Option<f64>,
}

impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: val.partial_rotary_factor,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

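/// Self-attention block with a fused QKV projection (`qkv_proj`) and an output
/// projection (`o_proj`). Rotary embeddings are applied via the shared
/// `PhiRotaryEmbedding`; attention itself runs through either `PagedAttention`
/// or the plain SDPA path, depending on how the model was constructed.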
struct Attention {
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        let qkv_proj = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            op_size,
            &cfg.quantization_config,
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

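    /// Run attention for one layer. If the layer holds quantized weights, the
    /// activations are cast to the quantized activation dtype around each matmul
    /// and cast back afterwards, so callers always see `original_dtype`.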
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
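        // Split the fused projection back into Q, K and V: the first
        // `num_heads * head_dim` columns are Q, followed by K and V with
        // `num_kv_heads * head_dim` columns each.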
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

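/// Gated feed-forward block using a fused `gate_up_proj` followed by `down_proj`:
/// the first half of the fused output is passed through the activation and used
/// to gate the second half. `params` stores `[hidden_size, intermediate_size]`
/// for AnyMoE expert construction.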
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    i_size: usize,
    params: Vec<usize>,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        let gate_up_proj = mistralrs_quant::linear_no_bias(
            hidden_size,
            2 * i_size,
            &cfg.quantization_config,
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = mistralrs_quant::linear_no_bias(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
            params: vec![hidden_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_gate_up = if let Some(ref delta) = deltas[0] {
            self.gate_up_proj.add_delta_w(delta)?
        } else {
            self.gate_up_proj.clone()
        };
        let new_down = if let Some(ref delta) = deltas[1] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_up_proj: new_gate_up,
            down_proj: new_down,
            act_fn: self.act_fn,
            i_size: self.i_size,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_up_proj.dtype_and_device()
    }
}

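/// A single transformer block: pre-norm attention followed by a pre-norm MLP,
/// each wrapped in a residual connection.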
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
        )?;
        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

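/// The full causal language model: token embeddings, a stack of decoder layers
/// (distributed across devices by the `DeviceMapper`), a final RMS norm, and the
/// LM head. When `tie_word_embeddings` is set, the LM head reuses the embedding
/// matrix instead of loading a separate `lm_head` weight.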
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

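    /// Full forward pass: embed the input ids, run every decoder layer (moving
    /// activations between devices as dictated by the mapper), apply the final
    /// norm, and project with the LM head, returning logits only for the
    /// positions requested in `context_lens`.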
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
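        // With paged-attention metadata present, keep the mask only for the first
        // prompt chunk; later chunks (and decode steps) run without an explicit mask.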
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

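// In-situ quantization support: expose the quantizable linear layers (lm_head,
// attention projections, MLP projections) plus the residual, non-quantized
// tensors that must be kept when the model is re-serialized.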
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        let mut names = Vec::new();
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_qkv.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

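// `NormalModel` is the common inference interface used by the pipeline; X-LoRA
// is not supported for this architecture, so `xlora_forward` is unimplemented.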
impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

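// AnyMoE support: each decoder layer's MLP can be swapped for a `MoeMlp` that
// gates between the original MLP and additional experts, either freshly loaded
// (`FineTuned`) or derived from LoRA deltas applied to the base weights
// (`LoraAdapter`).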
impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_up_proj_delta =
                            if target_modules.contains(&"gate_up_proj".to_string()) {
                                Some(get_delta_from_lora_ab!(
                                    vb_mlp,
                                    rank,
                                    alpha,
                                    (hidden_size, 2 * intermediate_size),
                                    "gate_up_proj"
                                ))
                            } else {
                                None
                            };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
                        );
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}