#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, Activation, CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

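/// Model configuration, typically deserialized from the checkpoint's `config.json`-style
/// metadata. Covers attention geometry, RMSNorm/RoPE hyperparameters, optional sliding-window
/// attention, the MoE routing setup (`num_local_experts` experts, `num_experts_per_tok` chosen
/// per token), and an optional quantization config.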
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) num_experts_per_tok: usize,
    pub(crate) num_local_experts: usize,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

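/// Multi-head self-attention with grouped-query KV heads. Q/K/V use column-parallel projections
/// and the output projection is row-parallel, so the `num_heads`/`num_kv_heads` stored here are
/// per-rank counts after dividing by the tensor-parallel world size.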
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

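    /// Projects `xs` to Q/K/V (casting to the quantized activation dtype when one is reported),
    /// applies rotary embeddings, then runs either paged attention (with or without a populated
    /// KV cache) or standard SDPA against the per-layer `KvCache`, and finally applies the
    /// output projection.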
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

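/// A single MoE expert: a gated feed-forward network of the form `w2(act_fn(w1(x)) * w3(x))`,
/// with `w1`/`w3` column-parallel and `w2` row-parallel across the tensor-parallel group.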
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            intermediate_sz,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }
}

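// Expert forward pass: activate the `w1` branch, multiply elementwise with the `w3` branch,
// then project back down through `w2`, casting to the quantized activation dtype as needed.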
impl Module for BlockSparseTop2MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

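/// The sparse MoE block: a linear router (`gate`) over `num_local_experts` experts, of which
/// `num_experts_per_tok` are evaluated for each token.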
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn QuantMethod>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let gate = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            &cfg.quantization_config,
            vb.pp("gate"),
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx), comm)?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }
}

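// Token routing and expert dispatch. Router logits are softmaxed, the top `num_experts_per_tok`
// experts per token are selected on the host, and their routing weights are renormalized to sum
// to 1. Each expert then runs once over the rows assigned to it (`index_select`), and its
// weighted output is scattered back into the result with `index_add`.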
impl Module for SparseMoeBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            let current_hidden_states = expert_layer.forward(&current_state)?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

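/// One transformer decoder layer: pre-norm self-attention followed by a pre-norm sparse MoE
/// block, each wrapped in a residual connection.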
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

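    /// `x = x + attn(norm(x))`, then `x = x + moe(norm(x))`, with the MoE output cast back to
    /// the residual dtype before the second addition.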
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs
            .apply(&self.post_attention_layernorm)?
            .apply(&self.block_sparse_moe)?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}

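/// The full causal language model: token embeddings, a stack of `DecoderLayer`s placed according
/// to the device mapper, a final RMSNorm, and an (optionally weight-tied) LM head.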
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

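// Construction builds one `RotaryEmbedding` per mapped device, loads the decoder layers through
// the device mapper via `par_iter_if_isq`, and either loads a dedicated `lm_head` or ties it to
// the embedding matrix when `tie_word_embeddings` is set.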
impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let vb_l = vb_m.pp("layers");
        let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )
        })?;
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

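    /// Full forward pass: embed tokens, build the (sliding-window) causal mask, run each decoder
    /// layer on its mapped device, apply the final norm, and project with the LM head, returning
    /// only the logits selected by `context_lens`.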
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

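// In-situ quantization support: exposes every quantizable linear (LM head, attention
// projections, MoE gate, and expert weights) with its layer index, and lists the residual
// non-quantized tensors (embeddings and norms) for serialization.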
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.push((&mut layer.block_sparse_moe.gate, Some(i)));
            for expert in &mut layer.block_sparse_moe.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}

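// Pipeline integration: delegates to `Model::forward` and reports cache, device, and config
// metadata. X-LoRA is not supported for this model, so `xlora_forward` is left unimplemented.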
impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

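// Marker implementation: this model relies on the `AnyMoeBaseModelMixin` defaults and adds no
// AnyMoE-specific overrides.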
impl AnyMoeBaseModelMixin for Model {}