#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, Activation, CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

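/// Configuration for this Mixtral-style sparse mixture-of-experts model, deserialized from the
/// model's `config.json`. `num_local_experts` is the number of experts per MoE layer and
/// `num_experts_per_tok` is how many of them each token is routed to.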
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) num_experts_per_tok: usize,
    pub(crate) num_local_experts: usize,
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
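        // Shard the K/V projections across the tensor-parallel group; `compute_kv_shard`
        // handles the case of fewer KV heads than ranks (note the `.max(1)` on `num_kv_heads`
        // below).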
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

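        // For prefill (q_len > 1), reshape to [b, heads, q_len, head_dim] via a transpose;
        // for single-token decode the tensors can be reshaped directly into that layout.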
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

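        // With paged attention, use the provided KV-cache blocks and metadata (or dummy
        // metadata when none is supplied); otherwise append k/v to the per-layer KV cache and
        // run attention through the regular SDPA path.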
        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

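/// A single expert: a gated feed-forward block where `w1` (gate, activated) and `w3` (up)
/// produce an elementwise product that is projected back down by `w2`.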
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            intermediate_sz,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for BlockSparseTop2MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

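/// The sparse MoE block: the router (`gate`) scores every expert per token, each token is
/// dispatched to its top `num_experts_per_tok` experts, and the expert outputs are combined
/// with renormalized routing weights.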
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn QuantMethod>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let gate = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            &cfg.quantization_config,
            vb.pp("gate"),
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx), comm)?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }
}

impl Module for SparseMoeBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

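        // Routing happens on the host: for each token, sort experts by routing weight, keep the
        // top `num_experts_per_tok`, and renormalize their weights to sum to 1. `top_x[e]`
        // collects the row indices routed to expert `e` and `selected_rws[e]` their weights.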
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        let mut ys = xs.zeros_like()?;
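        // Gather each expert's assigned rows with `index_select`, run the expert, scale by the
        // routing weights, and scatter-add the results back into `ys` with `index_add`.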
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            let current_hidden_states = expert_layer.forward(&current_state)?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs
            .apply(&self.post_attention_layernorm)?
            .apply(&self.block_sparse_moe)?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}

pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
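        // Build one RotaryEmbedding per unique device so that layers mapped to different
        // devices can share the same cached sin/cos tables.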
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
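        // With tied word embeddings, reuse the embedding matrix as the LM head instead of
        // loading a separate `lm_head` weight.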
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
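        // Build the (sliding-window) causal mask, taking past lengths from the sequence
        // offsets when paged-attention metadata is present, otherwise from the KV cache; on
        // later prompt chunks the mask is dropped.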
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.push((&mut layer.block_sparse_moe.gate, Some(i)));
            for expert in &mut layer.block_sparse_moe.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {}