1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{DType, Device, Module, Result, Tensor};
7use mistralrs_quant::{
8 ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
9 ShardedVarBuilder,
10};
11use serde::{Deserialize, Serialize};
12use std::{collections::HashMap, sync::Arc};
13
14use crate::{
15 amoe::AnyMoeBaseModelMixin,
16 attention::SdpaParams,
17 device_map::DeviceMapper,
18 layers::{self, Activation, CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa},
19 layers_masker::PastKvLenCache,
20 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
21 pipeline::{
22 extract_logits,
23 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
24 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
25 },
26 serde_default_fn,
27 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
28};
29
// Serde default for `Config::tie_word_embeddings` when the field is absent.
serde_default_fn!(bool, word_emb_default, false);
31
/// Sparse-MoE transformer configuration (serde round-trippable).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    // Per-expert MLP inner dimension.
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    // May be fewer than `num_attention_heads` (grouped-query attention).
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    // Sliding-window attention width; `None` disables windowing.
    pub(crate) sliding_window: Option<usize>,
    // MoE routing: experts selected per token, out of `num_local_experts`.
    pub(crate) num_experts_per_tok: usize,
    pub(crate) num_local_experts: usize,
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    // Reuse the embedding matrix as the LM head when true.
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}
53
/// Self-attention with rotary embeddings, optional paged attention, and
/// tensor-parallel (sharded) Q/K/V/O projections.
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    // Per-rank head counts (already divided by the TP world size in `new`).
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    // `Some` when running with the paged-attention backend.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
66
impl Attention {
    /// Build the attention projections from `vb`.
    ///
    /// Q/K/V are column-parallel and the output projection row-parallel over
    /// `comm`; the head counts stored on `Self` are therefore per-rank counts.
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        // K/V get a dedicated shard computed from the KV head count (KV heads
        // may be fewer than the world size).
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            // Per-rank head counts under tensor parallelism; at least one KV
            // head is kept per rank.
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    /// One attention forward pass.
    ///
    /// `metadata` carries the paged-attention KV-cache tensors plus input
    /// metadata; when `self.paged_attn` is `None`, `kv_cache` (the eager
    /// per-layer cache) is used instead.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        // If the projections require a quantized activation dtype, cast in
        // before the matmuls and back out after.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // Shape to (b, heads, q_len, head_dim). The q_len == 1 (single-token
        // decode) case can reshape directly without a transpose.
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                // Standard paged-attention path: use the provided KV-cache
                // tensors and input metadata.
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // No cache tensors supplied: run paged attention over only
                    // the current tokens with dummy metadata. A causal mask is
                    // mandatory in this mode.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                // Eager path: extend the per-layer KV cache, then run SDPA.
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        // With a mask the attention output is head-major and needs a transpose
        // before flattening to (b, q_len, hidden); otherwise it is already
        // token-major.
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
249
/// One expert: a gated MLP computing `w2(act_fn(w1(x)) * w3(x))`.
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}
257
258impl BlockSparseTop2MLP {
259 fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
260 let hidden_sz = cfg.hidden_size;
261 let intermediate_sz = cfg.intermediate_size;
262 let w1 = ColumnParallelLayer::new(
263 hidden_sz,
264 intermediate_sz,
265 &cfg.quantization_config,
266 false,
267 comm,
268 vb.pp("w1"),
269 )?;
270 let w2 = RowParallelLayer::new(
271 intermediate_sz,
272 hidden_sz,
273 &cfg.quantization_config,
274 false,
275 comm,
276 vb.pp("w2"),
277 )?;
278 let w3 = ColumnParallelLayer::new(
279 hidden_sz,
280 intermediate_sz,
281 &cfg.quantization_config,
282 false,
283 comm,
284 vb.pp("w3"),
285 )?;
286 Ok(Self {
287 w1,
288 w2,
289 w3,
290 act_fn: cfg.hidden_act,
291 })
292 }
293}
294
295impl Module for BlockSparseTop2MLP {
296 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
297 let original_dtype = xs.dtype();
298 let mut xs = xs.clone();
299 if let Some(t) = self.w1.quantized_act_type() {
300 xs = xs.to_dtype(t)?;
301 }
302 let lhs = MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
303 let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
304 let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.w2)?;
305 if self.w1.quantized_act_type().is_some() {
306 res = res.to_dtype(original_dtype)?;
307 }
308 Ok(res)
309 }
310}
311
/// Sparse MoE block: a router (`gate`) plus a pool of expert MLPs, of which
/// `num_experts_per_tok` are active per token.
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn QuantMethod>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}
318
319impl SparseMoeBlock {
320 fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
321 let gate = mistralrs_quant::linear_no_bias(
322 cfg.hidden_size,
323 cfg.num_local_experts,
324 &cfg.quantization_config,
325 vb.pp("gate"),
326 )?;
327 let mut experts = Vec::with_capacity(cfg.num_local_experts);
328 let vb = vb.pp("experts");
329 for idx in 0..cfg.num_local_experts {
330 let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx), comm)?;
331 experts.push(expert)
332 }
333 Ok(SparseMoeBlock {
334 gate,
335 experts,
336 num_experts_per_tok: cfg.num_experts_per_tok,
337 })
338 }
339}
340
impl Module for SparseMoeBlock {
    /// Sparse-MoE forward: softmax-route every token to its top
    /// `num_experts_per_tok` experts, run each expert only on the tokens
    /// assigned to it, and sum the outputs weighted by renormalized router
    /// scores.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        // Flatten (batch, seq) into a single token axis for routing.
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        // Routing decisions are made on the host over an f32 copy of the
        // router probabilities.
        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        // top_x[e]: token (row) indices routed to expert e.
        // selected_rws[e]: matching routing weight per token, renormalized
        // over the selected experts only.
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            // Expert indices sorted by descending router probability.
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            // Second pass: weights can only be normalized once the full top-k
            // sum is known.
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        // Accumulate expert outputs into `ys`; experts with no assigned
        // tokens are skipped entirely.
        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            // Gather this expert's tokens, run the expert, scale by routing
            // weight, and scatter-add back into the corresponding output rows.
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            let current_hidden_states = expert_layer.forward(&current_state)?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}
409
/// One pre-norm transformer layer: self-attention and a sparse-MoE
/// feed-forward block, each preceded by an RMS norm.
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}
416
417impl DecoderLayer {
418 #[allow(clippy::too_many_arguments)]
419 fn new(
420 rotary_emb: Arc<RotaryEmbedding>,
421 cfg: &Config,
422 vb: ShardedVarBuilder,
423 mapper: &dyn DeviceMapper,
424 layer_idx: usize,
425 loading_isq: bool,
426 paged_attn: Option<PagedAttention>,
427 comm: &Arc<mistralrs_quant::Comm>,
428 ) -> Result<Self> {
429 let self_attn = Attention::new(
430 rotary_emb,
431 cfg,
432 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
433 paged_attn,
434 comm,
435 )?;
436 let block_sparse_moe = SparseMoeBlock::new(
437 cfg,
438 mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
439 comm,
440 )?;
441 let input_layernorm = RmsNorm::new(
442 cfg.hidden_size,
443 cfg.rms_norm_eps,
444 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
445 )?;
446 let post_attention_layernorm = RmsNorm::new(
447 cfg.hidden_size,
448 cfg.rms_norm_eps,
449 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
450 )?;
451 Ok(Self {
452 self_attn,
453 block_sparse_moe,
454 input_layernorm,
455 post_attention_layernorm,
456 })
457 }
458
459 #[allow(clippy::too_many_arguments)]
460 fn forward(
461 &self,
462 xs: &Tensor,
463 attention_mask: Option<&Tensor>,
464 seqlen_offsets: &[usize],
465 kv_cache: &mut KvCache,
466 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
467 flash_params: &FlashParams,
468 ) -> Result<Tensor> {
469 let residual = xs;
470 let xs = self.input_layernorm.forward(xs)?;
471 let xs = self.self_attn.forward(
472 &xs,
473 attention_mask,
474 seqlen_offsets,
475 kv_cache,
476 metadata,
477 flash_params,
478 )?;
479 let xs = (xs + residual)?;
480 let residual = &xs;
481 let xs = xs
482 .apply(&self.post_attention_layernorm)?
483 .apply(&self.block_sparse_moe)?
484 .to_dtype(residual.dtype())?;
485 residual + xs
486 }
487}
488
/// The full sparse-MoE causal language model: embeddings, decoder layers,
/// final norm, and LM head.
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    // Device the final logits are produced on.
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    // Maps each layer (and its activations) to its assigned device.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}
501
502impl Model {
503 pub fn new(
504 cfg: &Config,
505 vb: ShardedVarBuilder,
506 is_gptx: bool,
507 normal_loading_metadata: NormalLoadingMetadata,
508 attention_mechanism: AttentionImplementation,
509 ) -> Result<Self> {
510 if let Some(ref quant_cfg) = &cfg.quantization_config {
511 tracing::info!(
512 "Using {} quantization: {}.",
513 quant_cfg.quant_method.to_string(),
514 quant_cfg.get_bits_name(&vb)
515 );
516 }
517 let mapper = normal_loading_metadata.mapper;
518 let vb_m = vb.pp("model");
519
520 let embed_tokens = layers::embedding(
521 cfg.vocab_size,
522 cfg.hidden_size,
523 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
524 )?;
525 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
526 let mut ropes = HashMap::new();
527 for layer_idx in 0..cfg.num_hidden_layers {
528 let device = mapper
529 .device_for(layer_idx, false)
530 .unwrap_or(&normal_loading_metadata.real_device);
531 ropes.insert(
532 device.location(),
533 Arc::new(RotaryEmbedding::new(
534 cfg.rope_theta as f32,
535 head_dim,
536 cfg.max_position_embeddings,
537 device,
538 is_gptx,
539 vb_m.dtype(),
540 )?),
541 );
542 }
543 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
544 let vb_l = vb_m.pp("layers");
545 for layer_idx in NiceProgressBar::<_, 'b'>(
546 0..cfg.num_hidden_layers,
547 "Loading repeating layers",
548 &normal_loading_metadata.multi_progress,
549 ) {
550 let device = mapper
551 .device_for(layer_idx, false)
552 .unwrap_or(&normal_loading_metadata.real_device);
553 let rotary_emb = ropes
554 .get(&device.location())
555 .expect("No RoPE for device location!")
556 .clone();
557 let paged_attn = match &attention_mechanism {
558 AttentionImplementation::Eager => None,
559 AttentionImplementation::PagedAttention => {
560 Some(PagedAttention::new(head_dim, device, None)?)
561 }
562 };
563 let comm = mapper.get_comm_for(layer_idx)?;
564 let layer = DecoderLayer::new(
565 rotary_emb.clone(),
566 cfg,
567 vb_l.pp(layer_idx),
568 &*mapper,
569 layer_idx,
570 normal_loading_metadata.loading_isq,
571 paged_attn,
572 &comm,
573 )?;
574 layers.push(layer)
575 }
576 let norm = RmsNorm::new(
577 cfg.hidden_size,
578 cfg.rms_norm_eps,
579 mapper.set_nm_device(vb_m.pp("norm"), false),
580 )?;
581 let lm_head = if !cfg.tie_word_embeddings {
582 ReplicatedLayer::new(
583 cfg.hidden_size,
584 cfg.vocab_size,
585 &None,
586 false,
587 mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
588 )?
589 } else {
590 ReplicatedLayer::from_linear(candle_nn::Linear::new(
591 mapper.cast_nm_device(
592 embed_tokens.embeddings(),
593 normal_loading_metadata.loading_isq,
594 )?,
595 None,
596 ))?
597 };
598 Ok(Self {
599 embed_tokens,
600 layers,
601 norm,
602 lm_head,
603 sliding_window: cfg.sliding_window,
604 device: normal_loading_metadata.real_device,
605 cache: EitherCache::Normal(NormalCache::new_sliding(
606 cfg.num_hidden_layers,
607 cfg.max_position_embeddings,
608 cfg.sliding_window,
609 )),
610 max_seq_len: cfg.max_position_embeddings,
611 cfg: ModelConfigMetadata {
612 max_seq_len: cfg.max_position_embeddings,
613 num_layers: cfg.num_hidden_layers,
614 hidden_size: cfg.hidden_size,
615 num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
616 num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
617 .max(1),
618 sliding_window: cfg.sliding_window,
619 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
620 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
621 },
622 mapper,
623 })
624 }
625
626 pub fn forward(
627 &self,
628 input_ids: &Tensor,
629 seqlen_offsets: &[usize],
630 context_lens: Vec<(usize, usize)>,
631 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
632 flash_params: &FlashParams,
633 ) -> Result<Tensor> {
634 let mut xs = self.embed_tokens.forward(input_ids)?;
635 let cache = &mut self.cache.normal().0;
636 let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
637 input_ids,
638 metadata
639 .as_ref()
640 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
641 .unwrap_or(cache as &dyn PastKvLenCache),
642 self.sliding_window,
643 xs.dtype(),
644 self.cfg.num_attn_heads,
645 )?;
646 let attention_mask = attention_mask.filter(|_| {
648 metadata
649 .as_ref()
650 .map(|(_, meta)| meta.is_first_prompt_chunk)
651 .unwrap_or(true)
652 });
653 for (i, layer) in self.layers.iter().enumerate() {
654 xs = self.mapper.map(xs, i)?;
655 xs = layer.forward(
656 &xs,
657 attention_mask
658 .as_ref()
659 .map(|m| m.to_device(xs.device()).unwrap())
660 .as_ref(),
661 seqlen_offsets,
662 &mut cache[i],
663 metadata
664 .as_ref()
665 .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
666 flash_params,
667 )?;
668 }
669 let xs = xs.to_device(&self.device)?;
670 let mut xs = xs.apply(&self.norm)?;
671 if let Some(t) = self.lm_head.quantized_act_type() {
672 xs = xs.to_dtype(t)?;
673 }
674 extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
675 }
676}
677
678impl IsqModel for Model {
679 fn get_layers(
680 &mut self,
681 ) -> (
682 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
683 &dyn DeviceMapper,
684 ) {
685 let mut tensors = Vec::new();
686 tensors.push((&mut self.lm_head, None));
687 for (i, layer) in self.layers.iter_mut().enumerate() {
688 tensors.push((&mut layer.self_attn.q_proj, Some(i)));
689 tensors.push((&mut layer.self_attn.k_proj, Some(i)));
690 tensors.push((&mut layer.self_attn.v_proj, Some(i)));
691 tensors.push((&mut layer.self_attn.o_proj, Some(i)));
692 tensors.push((&mut layer.block_sparse_moe.gate, Some(i)));
693 for expert in &mut layer.block_sparse_moe.experts {
694 tensors.push((&mut expert.w1, Some(i)));
695 tensors.push((&mut expert.w2, Some(i)));
696 tensors.push((&mut expert.w3, Some(i)));
697 }
698 }
699 (tensors, &*self.mapper)
700 }
701
702 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
703 let uvb = UnVarBuilder::new();
704
705 let uvb_m = uvb.pp("model");
706 uvb_m.pp("embed_tokens").add(&self.embed_tokens);
707 uvb_m.pp("norm").add(&self.norm);
708
709 for (layer_idx, layer) in self.layers.iter().enumerate() {
710 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
711 uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
712 uvb_l
713 .pp("post_attention_layernorm")
714 .add(&layer.post_attention_layernorm);
715 }
716
717 uvb.to_safetensors()
718 }
719}
720
impl NormalModel for Model {
    /// Delegates to [`Model::forward`]; `_position_ids` is unused by this
    /// architecture.
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    /// X-LoRA is not supported for this model; calling this panics.
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}
773
// Marker impl: relies entirely on `AnyMoeBaseModelMixin`'s default methods.
impl AnyMoeBaseModelMixin for Model {}