#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, Module, Result, Tensor};
use candle_nn::Linear;
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantMethodConfig, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder, UnquantLinear,
};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

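/// Text-model configuration. The fields (sliding window, attention/final logit
/// softcapping, `query_pre_attn_scalar`, and the pre/post feed-forward norms used
/// below) follow a Gemma 2-style layout.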
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct Config {
    pub attention_bias: bool,
    pub head_dim: usize,
    pub hidden_act: Option<Activation>,
    pub hidden_activation: Option<Activation>,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub vocab_size: usize,
    pub sliding_window: usize,
    pub attn_logit_softcapping: Option<f64>,
    pub final_logit_softcapping: Option<f64>,
    pub query_pre_attn_scalar: usize,
    pub max_position_embeddings: usize,
    pub quantization_config: Option<QuantizedConfig>,
    pub use_flash_attn: bool,
    #[allow(dead_code)]
    pub tie_word_embeddings: bool,
}

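// Checkpoints disagree on whether the activation is stored as `hidden_act` or
// `hidden_activation`; accept either, and error only if both are missing.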
impl Config {
    pub fn hidden_act(&self) -> Result<Activation> {
        match (self.hidden_act, self.hidden_activation) {
            (None, Some(act)) | (Some(act), None) => Ok(act),
            // If both are present, use `hidden_act`.
            (Some(act), Some(_)) => Ok(act),
            (None, None) => {
                candle_core::bail!("neither `hidden_act` nor `hidden_activation` is set")
            }
        }
    }
}
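
/// Multi-head attention with tensor-parallel Q/K/V/O projections. Following the
/// Gemma 2 pattern, even-indexed layers restrict attention to a sliding window
/// while odd-indexed layers attend globally.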
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    use_sliding_window: bool,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        layer_idx: usize,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("o_proj"),
        )?;
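        // Even-indexed layers use sliding-window attention; odd-indexed layers are global.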
        let sliding_window = if layer_idx % 2 == 0 {
            Some(cfg.sliding_window)
        } else {
            None
        };
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            use_sliding_window: layer_idx % 2 == 0,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
                softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
                sliding_window,
            },
        })
    }

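    /// Project Q/K/V, apply rotary embeddings, then run attention either through the
    /// paged-attention kernel (when KV-cache metadata is available) or through the
    /// regular SDPA path with the in-memory `KvCache`. Sliding-window layers receive
    /// the sliding mask; global layers receive the full causal mask.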
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mask = if self.use_sliding_window {
            sliding_attention_mask
        } else {
            attention_mask
        };

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
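                    // No paged KV-cache metadata was supplied (e.g. a prompt-only pass such
                    // as imatrix collection): build dummy metadata on the query's device and
                    // rely on the causal mask instead of the cache.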
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

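/// A single transformer block. Gemma 2-style "sandwich" normalization is used:
/// RmsNorm before and after attention, and again before and after the MLP.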
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    pre_feedforward_layernorm: RmsNorm,
    post_feedforward_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            layer_idx,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_act()?,
            comm,
        )?;
        let input_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let pre_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
        )?;
        let post_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
            pre_feedforward_layernorm,
            post_feedforward_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(
                &xs,
                attention_mask,
                sliding_attention_mask,
                seqlen_offsets,
                kv_cache,
                metadata,
                flash_params,
            )?
            .apply(&self.post_attention_layernorm)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.pre_feedforward_layernorm)?)?
            .apply(&self.post_feedforward_layernorm)?;
        residual + xs
    }
}

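/// The full decoder-only model: token embeddings, the stack of decoder layers,
/// a final RmsNorm, and an `lm_head` whose weights are tied to the embedding matrix.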
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    hidden_size: usize,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: usize,
    final_logit_softcapping: Option<f64>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
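        // Tie the output projection to the embedding matrix: `lm_head` reuses the
        // `embed_tokens` weights (cast/moved as needed) wrapped as an unquantized linear.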
        let lm_head = mapper.cast_nm_device(
            embed_tokens.embeddings(),
            normal_loading_metadata.loading_isq,
        )?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                Linear::new(lm_head, None),
            ))?),
            device: normal_loading_metadata.real_device,
            hidden_size: cfg.hidden_size,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                Some(cfg.sliding_window),
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            final_logit_softcapping: cfg.final_logit_softcapping,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.head_dim,
                v_head_dim: cfg.head_dim,
            },
            mapper,
        })
    }

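    /// Run the model over `input_ids`. Embeddings are scaled by `sqrt(hidden_size)`,
    /// each layer receives both a global causal mask and a sliding-window mask (its
    /// attention kind decides which one it uses), and the final logits are optionally
    /// softcapped before `extract_logits` selects the requested positions.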
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let xs = self.embed_tokens.forward(input_ids)?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.sliding_window),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let sliding_attention_mask = sliding_attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                sliding_attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }

        let mut xs = MatMul.qmethod_matmul(&xs, &*self.lm_head)?;

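        // Final logit softcapping: logits = softcap * tanh(logits / softcap),
        // which bounds the logits to (-softcap, softcap).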
        if let Some(final_logit_softcapping) = self.final_logit_softcapping {
            xs = (xs / final_logit_softcapping)?;
            xs = xs.tanh()?;
            xs = (xs * final_logit_softcapping)?;
        }

        extract_logits(&xs, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l
                .pp("input_layernorm")
                .add(&layer.input_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("pre_feedforward_layernorm")
                .add(&layer.pre_feedforward_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("post_feedforward_layernorm")
                .add(&layer.post_feedforward_layernorm.undo_gemma().unwrap());
        }

        uvb.to_safetensors()
    }

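    // Tensor names follow the GGUF/llama.cpp convention (`blk.{i}.attn_q.weight`, ...);
    // the leading `None` corresponds to `lm_head` in `get_layers` above.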
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        let mut names = Vec::new();
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

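// AnyMoe support: each layer's MLP can be swapped for a `MoeMlp` gating over the
// original MLP plus per-expert copies (fine-tuned weights or LoRA deltas).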
impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.layers[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.layers[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "gate_proj"
                            ))
                        } else {
                            None
                        };
                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "up_proj"
                            ))
                        } else {
                            None
                        };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.layers[layer].mlp.new_added_delta(vec![
                            gate_proj_delta,
                            up_proj_delta,
                            down_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}