1use std::{collections::HashMap, sync::Arc};
2
3use candle_core::{Device, Module, Result, Tensor};
4use mistralrs_quant::{
5 ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
6};
7
8use crate::{
9 amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
10 attention::SdpaParams,
11 device_map::DeviceMapper,
12 get_delta_from_lora_ab,
13 layers::{
14 embedding, CausalMasker, Gemma3RotaryEmbedding, MatMul, Mlp, RmsNorm, RotaryEmbedding,
15 ScaledEmbedding, Sdpa,
16 },
17 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
18 pipeline::{
19 extract_logits,
20 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
21 EitherCache, IsqModel, KvCache, NormalCache, NormalCacheType, NormalLoadingMetadata,
22 VisionModel,
23 },
24 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
25};
26
27use super::config::Gemma3TextConfig;
28
/// Evaluates to `true` when decoder layer `$layer_idx` uses sliding-window (local) attention.
///
/// Every `sliding_window_pattern`-th layer (1-based) is a global-attention layer; all the
/// others are sliding-window layers.
macro_rules! is_sliding {
    ($layer_idx:expr, $cfg:expr) => {{
        let position_in_pattern = ($layer_idx + 1) % $cfg.sliding_window_pattern;
        position_in_pattern != 0
    }};
}
34
/// Self-attention for a single Gemma3 decoder layer.
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    /// Query heads held by this rank (`num_attention_heads / world_size`).
    num_heads: usize,
    /// KV heads held by this rank (`num_key_value_heads / world_size`, at least 1).
    num_kv_heads: usize,
    head_dim: usize,
    /// RoPE applied on global (non-sliding) layers.
    rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
    /// RoPE applied on sliding-window layers.
    rotary_emb_local: Arc<RotaryEmbedding>,
    /// Whether this layer is a sliding-window layer (decided by `is_sliding!`).
    use_sliding_window: bool,
    /// Set only when paged attention is the selected attention implementation.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    /// Per-head RMS norm applied to the queries before RoPE.
    q_norm: RmsNorm,
    /// Per-head RMS norm applied to the keys before RoPE.
    k_norm: RmsNorm,
}
51
52impl Attention {
53 #[allow(clippy::too_many_arguments)]
54 fn new(
55 rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
56 rotary_emb_local: Arc<RotaryEmbedding>,
57 cfg: &Gemma3TextConfig,
58 layer_idx: usize,
59 mapper: &dyn DeviceMapper,
60 vb: ShardedVarBuilder,
61 paged_attn: Option<PagedAttention>,
62 comm: &Arc<mistralrs_quant::Comm>,
63 ) -> Result<Self> {
64 let hidden_sz = cfg.hidden_size;
65 let num_heads = cfg.num_attention_heads;
66 let num_kv_heads = cfg.num_key_value_heads;
67 let head_dim = cfg.head_dim;
68 let bias = cfg.attention_bias;
69 let q_proj = ColumnParallelLayer::new(
70 hidden_sz,
71 num_heads * head_dim,
72 &cfg.quantization_config,
73 bias,
74 comm,
75 vb.pp("q_proj"),
76 )?;
77 let kv_shard = mistralrs_quant::compute_kv_shard(
78 cfg.num_key_value_heads,
79 cfg.hidden_size / cfg.num_attention_heads,
80 comm,
81 );
82 let k_proj = ColumnParallelLayer::new_with_shard(
83 hidden_sz,
84 num_kv_heads * head_dim,
85 &cfg.quantization_config,
86 bias,
87 comm,
88 kv_shard,
89 vb.pp("k_proj"),
90 )?;
91 let v_proj = ColumnParallelLayer::new_with_shard(
92 hidden_sz,
93 num_kv_heads * head_dim,
94 &cfg.quantization_config,
95 bias,
96 comm,
97 kv_shard,
98 vb.pp("v_proj"),
99 )?;
100 let o_proj = RowParallelLayer::new(
101 num_heads * head_dim,
102 hidden_sz,
103 &cfg.quantization_config,
104 bias,
105 comm,
106 vb.pp("o_proj"),
107 )?;
108 let sliding_window = if is_sliding!(layer_idx, cfg) {
109 Some(cfg.sliding_window)
110 } else {
111 None
112 };
113
114 let q_norm = RmsNorm::new_gemma(
115 cfg.head_dim,
116 cfg.rms_norm_eps,
117 mapper.set_device(layer_idx, vb.pp("q_norm"), false),
118 )?;
119 let k_norm = RmsNorm::new_gemma(
120 cfg.head_dim,
121 cfg.rms_norm_eps,
122 mapper.set_device(layer_idx, vb.pp("k_norm"), false),
123 )?;
124 Ok(Self {
125 q_proj,
126 k_proj,
127 v_proj,
128 o_proj,
129 num_heads: num_heads / comm.world_size(),
130 num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
131 head_dim,
132 rotary_emb_global,
133 rotary_emb_local,
134 use_sliding_window: sliding_window.is_some(),
135 paged_attn,
136 sdpa_params: SdpaParams {
137 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
138 cfg.num_key_value_heads,
139 cfg.num_attention_heads,
140 comm,
141 ),
142 use_flash_attn: cfg.use_flash_attn,
143 softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
144 softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
145 sliding_window,
146 },
147 q_norm,
148 k_norm,
149 })
150 }
151
152 #[allow(clippy::too_many_arguments)]
153 fn forward(
154 &self,
155 xs: &Tensor,
156 attention_mask: Option<&Tensor>,
157 sliding_attention_mask: Option<&Tensor>,
158 seqlen_offsets: &[usize],
159 kv_cache: &mut KvCache,
160 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
161 flash_params: &FlashParams,
162 ) -> Result<Tensor> {
163 let (b_sz, q_len, _) = xs.dims3()?;
164
165 let original_dtype = xs.dtype();
166 let mut xs = xs.clone();
167 if let Some(t) = self.q_proj.quantized_act_type() {
168 xs = xs.to_dtype(t)?;
169 }
170 let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
171 let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
172 let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
173 if self.q_proj.quantized_act_type().is_some() {
174 q = q.to_dtype(original_dtype)?;
175 k = k.to_dtype(original_dtype)?;
176 v = v.to_dtype(original_dtype)?;
177 }
178
179 (q, k, v) = if q_len != 1 {
180 let q = q
181 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
182 .transpose(1, 2)?;
183 let k = k
184 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
185 .transpose(1, 2)?;
186 let v = v
187 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
188 .transpose(1, 2)?;
189 (q, k, v)
190 } else {
191 let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
192 let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
193 let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
194 (q, k, v)
195 };
196
197 q = q.apply(&self.q_norm)?;
198 k = k.apply(&self.k_norm)?;
199
200 (q, k) = match self.use_sliding_window {
201 true => self.rotary_emb_local.forward(&q, &k, seqlen_offsets)?,
202 false => self.rotary_emb_global.forward(&q, &k, seqlen_offsets)?,
203 };
204
205 let mask = if self.use_sliding_window {
206 sliding_attention_mask
207 } else {
208 attention_mask
209 };
210
211 let mut attn_output = match &self.paged_attn {
212 Some(paged_attn) => match metadata {
213 Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
214 &q,
215 &k,
216 &v,
217 attention_mask,
218 Some(key_cache),
219 Some(value_cache),
220 input_metadata,
221 &self.sdpa_params,
222 Some(flash_params),
223 )?,
224 None => {
225 let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
228 assert!(attention_mask.is_some());
230 paged_attn.forward(
231 &q,
232 &k,
233 &v,
234 attention_mask,
235 None,
236 None,
237 &input_metadata,
238 &self.sdpa_params,
239 Some(flash_params),
240 )?
241 }
242 },
243 None => {
244 let (k, v) = kv_cache.append(&k, &v)?;
246
247 Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
248 }
249 };
250
251 if let Some(t) = self.q_proj.quantized_act_type() {
252 attn_output = attn_output.to_dtype(t)?;
253 }
254 attn_output = if attention_mask.is_some() {
255 attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
256 } else {
257 attn_output.reshape((b_sz, q_len, ()))?
258 };
259 let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
260 if self.q_proj.quantized_act_type().is_some() {
261 res = res.to_dtype(original_dtype)?;
262 }
263 Ok(res)
264 }
265}
266
/// One Gemma3 decoder layer: an attention sub-block and an MLP sub-block. Each sub-block is
/// wrapped by a pre-norm and a post-norm before its residual add.
struct DecoderLayer {
    self_attn: Attention,
    /// Dense MLP, replaced by an AnyMoE mixture in `create_anymoe_layers`.
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    pre_feedforward_layernorm: RmsNorm,
    post_feedforward_layernorm: RmsNorm,
}
275
276impl DecoderLayer {
277 #[allow(clippy::too_many_arguments)]
278 fn new(
279 rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
280 rotary_emb_local: Arc<RotaryEmbedding>,
281 cfg: &Gemma3TextConfig,
282 vb: ShardedVarBuilder,
283 mapper: &dyn DeviceMapper,
284 layer_idx: usize,
285 loading_isq: bool,
286 paged_attn: Option<PagedAttention>,
287 comm: &Arc<mistralrs_quant::Comm>,
288 ) -> Result<Self> {
289 let self_attn = Attention::new(
290 rotary_emb_global,
291 rotary_emb_local,
292 cfg,
293 layer_idx,
294 mapper,
295 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
296 paged_attn,
297 comm,
298 )?;
299 let mlp = Mlp::new(
300 mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
301 cfg.hidden_size,
302 cfg.intermediate_size,
303 &cfg.quantization_config,
304 cfg.hidden_activation,
305 comm,
306 )?;
307 let input_layernorm = RmsNorm::new_gemma(
308 cfg.hidden_size,
309 cfg.rms_norm_eps,
310 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
311 )?;
312 let post_attention_layernorm = RmsNorm::new_gemma(
313 cfg.hidden_size,
314 cfg.rms_norm_eps,
315 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
316 )?;
317 let pre_feedforward_layernorm = RmsNorm::new_gemma(
318 cfg.hidden_size,
319 cfg.rms_norm_eps,
320 mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
321 )?;
322 let post_feedforward_layernorm = RmsNorm::new_gemma(
323 cfg.hidden_size,
324 cfg.rms_norm_eps,
325 mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
326 )?;
327 Ok(Self {
328 self_attn,
329 mlp: Box::new(mlp),
330 input_layernorm,
331 post_attention_layernorm,
332 pre_feedforward_layernorm,
333 post_feedforward_layernorm,
334 })
335 }
336
337 #[allow(clippy::too_many_arguments)]
338 fn forward(
339 &self,
340 xs: &Tensor,
341 attention_mask: Option<&Tensor>,
342 sliding_attention_mask: Option<&Tensor>,
343 seqlen_offsets: &[usize],
344 kv_cache: &mut KvCache,
345 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
346 flash_params: &FlashParams,
347 ) -> Result<Tensor> {
348 let residual = xs;
349 let xs = self.input_layernorm.forward(xs)?;
350 let xs = self
351 .self_attn
352 .forward(
353 &xs,
354 attention_mask,
355 sliding_attention_mask,
356 seqlen_offsets,
357 kv_cache,
358 metadata,
359 flash_params,
360 )?
361 .apply(&self.post_attention_layernorm)?;
362 let xs = (xs + residual)?;
363 let residual = &xs;
364 let xs = self
365 .mlp
366 .forward(&xs.apply(&self.pre_feedforward_layernorm)?)?
367 .apply(&self.post_feedforward_layernorm)?;
368 residual + xs
369 }
370}
371
/// Gemma3 text decoder: scaled token embeddings, a stack of decoder layers, a final norm and
/// the LM head.
pub struct TextModel {
    embed_tokens: ScaledEmbedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    /// Separate `lm_head` weights, or the embedding matrix when embeddings are tied.
    lm_head: Arc<dyn QuantMethod>,
    /// Device on which the final norm, LM head and logit extraction run.
    device: Device,
    /// Per-layer KV caches; sliding-window layers get a sliding-window cache.
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Window size used to build the sliding-window attention mask.
    sliding_window: usize,
    /// When set, logits are squashed as `cap * tanh(logits / cap)`.
    final_logit_softcapping: Option<f64>,
    cfg: ModelConfigMetadata,
}
385
386impl TextModel {
387 pub fn new(
388 cfg: &Gemma3TextConfig,
389 vb: ShardedVarBuilder,
390 is_gptx: bool,
391 normal_loading_metadata: NormalLoadingMetadata,
392 attention_mechanism: AttentionImplementation,
393 ) -> Result<Self> {
394 if let Some(ref quant_cfg) = &cfg.quantization_config {
395 tracing::info!(
396 "Using {} quantization: {}.",
397 quant_cfg.quant_method.to_string(),
398 quant_cfg.get_bits_name(&vb)
399 );
400 }
401 let mapper = normal_loading_metadata.mapper;
402
403 let vb_m = vb.pp("model");
404 let embed_tokens = ScaledEmbedding::new(
405 (cfg.hidden_size as f64).sqrt(),
406 embedding(
407 cfg.vocab_size,
408 cfg.hidden_size,
409 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
410 )?,
411 );
412
413 let mut global_ropes = HashMap::new();
414 for layer_idx in 0..cfg.num_hidden_layers {
415 let device = mapper
416 .device_for(layer_idx, false)
417 .unwrap_or(&normal_loading_metadata.real_device);
418 global_ropes.insert(
419 device.location(),
420 Arc::new(Gemma3RotaryEmbedding::new(
421 is_gptx,
422 vb.dtype(),
423 cfg,
424 device,
425 )?),
426 );
427 }
428
429 let mut local_ropes = HashMap::new();
430 for layer_idx in 0..cfg.num_hidden_layers {
431 let device = mapper
432 .device_for(layer_idx, false)
433 .unwrap_or(&normal_loading_metadata.real_device);
434 local_ropes.insert(
435 device.location(),
436 Arc::new(RotaryEmbedding::new(
437 cfg.rope_local_base_freq as f32,
438 cfg.head_dim,
439 cfg.max_position_embeddings,
440 device,
441 is_gptx,
442 vb_m.dtype(),
443 )?),
444 );
445 }
446
447 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
448 let vb_l = vb_m.pp("layers");
449 for layer_idx in NiceProgressBar::<_, 'b'>(
450 0..cfg.num_hidden_layers,
451 "Loading repeating layers",
452 &normal_loading_metadata.multi_progress,
453 ) {
454 let device = mapper
455 .device_for(layer_idx, false)
456 .unwrap_or(&normal_loading_metadata.real_device);
457 let rotary_emb_global = global_ropes
458 .get(&device.location())
459 .expect("No RoPE for device location!")
460 .clone();
461 let rotary_emb_local = local_ropes
462 .get(&device.location())
463 .expect("No RoPE for device location!")
464 .clone();
465 let paged_attn = match &attention_mechanism {
466 AttentionImplementation::Eager => None,
467 AttentionImplementation::PagedAttention => {
468 Some(PagedAttention::new(cfg.head_dim, device, None)?)
469 }
470 };
471 let comm = mapper.get_comm_for(layer_idx)?;
472 let layer = DecoderLayer::new(
473 rotary_emb_global.clone(),
474 rotary_emb_local.clone(),
475 cfg,
476 vb_l.pp(layer_idx),
477 &*mapper,
478 layer_idx,
479 normal_loading_metadata.loading_isq,
480 paged_attn,
481 &comm,
482 )?;
483 layers.push(layer)
484 }
485 let norm = RmsNorm::new_gemma(
486 cfg.hidden_size,
487 cfg.rms_norm_eps,
488 mapper.set_nm_device(vb_m.pp("norm"), false),
489 )?;
490
491 let lm_head = if !cfg.tie_word_embeddings {
492 ReplicatedLayer::new(
493 cfg.hidden_size,
494 cfg.vocab_size,
495 &None,
496 false,
497 mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
498 )?
499 } else {
500 ReplicatedLayer::from_linear(candle_nn::Linear::new(
501 mapper.cast_nm_device(
502 embed_tokens.embeddings(),
503 normal_loading_metadata.loading_isq,
504 )?,
505 None,
506 ))?
507 };
508 let cache_types = (0..cfg.num_hidden_layers)
509 .map(|layer_idx| {
510 is_sliding!(layer_idx, cfg)
511 .then(|| NormalCacheType::SlidingWindow {
512 window: cfg.sliding_window,
513 })
514 .unwrap_or(NormalCacheType::Normal {
515 max_seq_len: cfg.max_position_embeddings,
516 })
517 })
518 .collect::<Vec<_>>();
519 Ok(Self {
520 embed_tokens,
521 layers,
522 norm,
523 lm_head,
524 device: normal_loading_metadata.real_device,
525 cache: EitherCache::Normal(NormalCache::from_types(cache_types)),
526 max_seq_len: cfg.max_position_embeddings,
527 sliding_window: cfg.sliding_window,
528 final_logit_softcapping: cfg.final_logit_softcapping,
529 cfg: ModelConfigMetadata {
530 max_seq_len: cfg.max_position_embeddings,
531 num_layers: cfg.num_hidden_layers,
532 hidden_size: cfg.hidden_size,
533 num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
534 num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
535 .max(1),
536 sliding_window: None,
537 k_head_dim: cfg.head_dim,
538 v_head_dim: cfg.head_dim,
539 },
540 mapper,
541 })
542 }
543
544 pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
545 self.embed_tokens.forward(input_ids)
546 }
547
548 pub fn forward_embeds(
549 &self,
550 input_ids: &Tensor,
551 mut xs: Tensor,
552 seqlen_offsets: &[usize],
553 context_lens: Vec<(usize, usize)>,
554 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
555 flash_params: &FlashParams,
556 ) -> Result<Tensor> {
557 let cache = &mut self.cache.normal().0;
558 let attention_mask = CausalMasker.make_causal_mask_matrix(
559 input_ids,
560 &*cache,
561 xs.dtype(),
562 self.cfg.num_attn_heads,
563 )?;
564 let attention_mask = attention_mask.filter(|_| {
566 metadata
567 .as_ref()
568 .map(|(_, meta)| meta.is_first_prompt_chunk)
569 .unwrap_or(true)
570 });
571 let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
572 input_ids,
573 &*cache,
574 Some(self.sliding_window),
575 xs.dtype(),
576 self.cfg.num_attn_heads,
577 )?;
578 let sliding_attention_mask = sliding_attention_mask.filter(|_| {
580 metadata
581 .as_ref()
582 .map(|(_, meta)| meta.is_first_prompt_chunk)
583 .unwrap_or(true)
584 });
585 for (i, layer) in self.layers.iter().enumerate() {
586 xs = self.mapper.map(xs, i)?;
587 xs = layer.forward(
588 &xs,
589 attention_mask
590 .as_ref()
591 .map(|m| m.to_device(xs.device()).unwrap())
592 .as_ref(),
593 sliding_attention_mask
594 .as_ref()
595 .map(|m| m.to_device(xs.device()).unwrap())
596 .as_ref(),
597 seqlen_offsets,
598 &mut cache[i],
599 metadata
600 .as_ref()
601 .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
602 flash_params,
603 )?;
604 }
605 let xs = xs.to_device(&self.device)?;
606 let mut xs = xs.apply(&self.norm)?;
607 if let Some(t) = self.lm_head.quantized_act_type() {
608 xs = xs.to_dtype(t)?;
609 }
610
611 let mut xs = MatMul.qmethod_matmul(&xs, &*self.lm_head)?;
612
613 if let Some(final_logit_softcapping) = self.final_logit_softcapping {
614 xs = (xs / final_logit_softcapping)?;
615 xs = xs.tanh()?;
616 xs = (xs * final_logit_softcapping)?;
617 }
618
619 extract_logits(&xs, context_lens)
620 }
621}
622
623impl IsqModel for TextModel {
624 fn get_layers(
625 &mut self,
626 ) -> (
627 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
628 &dyn DeviceMapper,
629 ) {
630 let mut tensors = Vec::new();
631 tensors.push((&mut self.lm_head, None));
632 for (i, layer) in self.layers.iter_mut().enumerate() {
633 tensors.push((&mut layer.self_attn.q_proj, Some(i)));
634 tensors.push((&mut layer.self_attn.k_proj, Some(i)));
635 tensors.push((&mut layer.self_attn.v_proj, Some(i)));
636 tensors.push((&mut layer.self_attn.o_proj, Some(i)));
637 tensors.extend(
638 layer
639 .mlp
640 .get_isq_layers()
641 .into_iter()
642 .map(|m| (m, Some(i)))
643 .collect::<Vec<_>>(),
644 );
645 }
646 (tensors, &*self.mapper)
647 }
648
649 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
650 let uvb = UnVarBuilder::new();
651
652 let uvb_m = uvb.pp("model");
653 uvb_m.pp("embed_tokens").add(&self.embed_tokens);
654 uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());
655
656 for (layer_idx, layer) in self.layers.iter().enumerate() {
657 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
658 uvb_l
659 .pp("input_layernorm")
660 .add(&layer.input_layernorm.undo_gemma().unwrap());
661 uvb_l
662 .pp("post_attention_layernorm")
663 .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
664 uvb_l
665 .pp("pre_feedforward_layernorm")
666 .add(&layer.pre_feedforward_layernorm.undo_gemma().unwrap());
667 uvb_l
668 .pp("post_feedforward_layernorm")
669 .add(&layer.post_feedforward_layernorm.undo_gemma().unwrap());
670 }
671
672 uvb.to_safetensors()
673 }
674
675 fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
676 let mut names = Vec::new();
678 names.push(None);
680 for i in 0..self.layers.len() {
681 names.push(Some(format!("blk.{i}.attn_q.weight")));
682 names.push(Some(format!("blk.{i}.attn_k.weight")));
683 names.push(Some(format!("blk.{i}.attn_v.weight")));
684 names.push(Some(format!("blk.{i}.attn_output.weight")));
685 names.push(Some(format!("blk.{i}.ffn_gate.weight")));
686 names.push(Some(format!("blk.{i}.ffn_up.weight")));
687 names.push(Some(format!("blk.{i}.ffn_down.weight")));
688 }
689 Ok(names)
690 }
691}
692
693impl VisionModel for TextModel {
694 fn forward(
695 &self,
696 _input_ids: &Tensor,
697 _pixel_values: Option<Tensor>,
698 _seqlen_offsets: &[usize],
699 _context_lens: Vec<(usize, usize)>,
700 _position_ids: Vec<usize>,
701 _model_specific_args: Box<dyn std::any::Any>, _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
703 _flash_params: &FlashParams,
704 ) -> candle_core::Result<Tensor> {
705 unreachable!()
706 }
707 fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
708 unreachable!()
709 }
710 fn cache(&self) -> &EitherCache {
711 &self.cache
712 }
713 fn cache_mut(&mut self) -> &mut EitherCache {
714 &mut self.cache
715 }
716 fn device(&self) -> &Device {
717 &self.device
718 }
719 fn max_seq_len(&self) -> usize {
720 self.max_seq_len
721 }
722 fn config(&self) -> &ModelConfigMetadata {
723 &self.cfg
724 }
725 fn has_conv2d(&self) -> bool {
726 unreachable!()
727 }
728}
729
730impl AnyMoeBaseModelMixin for TextModel {
731 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
732 let mut mlps = Vec::new();
733 for layer in &self.layers {
734 mlps.push(&*layer.mlp);
735 }
736 mlps
737 }
738 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
739 let mut mlps = Vec::new();
740 for layer in &mut self.layers {
741 mlps.push(&mut layer.mlp);
742 }
743 mlps
744 }
745 fn create_anymoe_layers(
746 &mut self,
747 additional_vbs: Vec<ShardedVarBuilder>,
748 config: AnyMoeConfig,
749 (prefix, mlp): (String, String),
750 mut layers: Vec<usize>,
751 expert_type: AnyMoeExpertType,
752 gate_vb: Option<ShardedVarBuilder>,
753 ) -> Result<()> {
754 let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
755 if layers.is_empty() {
756 layers = (0..self.layers.len()).collect::<Vec<_>>();
757 }
758 for _ in 0..layers.len() {
759 experts.push(Vec::new());
760 }
761 for vb in additional_vbs {
762 let vb = vb.pp(&prefix);
763 for (layer, row) in experts.iter_mut().enumerate() {
764 if !layers.contains(&layer) {
765 continue;
766 }
767
768 let intermediate_size = self.layers[layer].mlp.get_params()[1];
769 let hidden_size = self.layers[layer].mlp.get_params()[0];
770 match expert_type {
771 AnyMoeExpertType::FineTuned => {
772 let (dtype, device) = self.layers[layer].mlp.dtype_device();
773 row.push(Box::new(Mlp::replicate(
774 self.layers[layer].mlp.get_params(),
775 vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
776 self.layers[layer].mlp.hidden_act(),
777 &self.mapper.get_comm_for(layer)?,
778 )?));
779 }
780 AnyMoeExpertType::LoraAdapter {
781 rank,
782 alpha,
783 ref target_modules,
784 } => {
785 let vb_mlp = vb.pp(layer).pp(&mlp);
786
787 let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
788 Some(get_delta_from_lora_ab!(
789 vb_mlp,
790 rank,
791 alpha,
792 (hidden_size, intermediate_size),
793 "gate_proj"
794 ))
795 } else {
796 None
797 };
798 let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
799 Some(get_delta_from_lora_ab!(
800 vb_mlp,
801 rank,
802 alpha,
803 (hidden_size, intermediate_size),
804 "up_proj"
805 ))
806 } else {
807 None
808 };
809 let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
810 Some(get_delta_from_lora_ab!(
811 vb_mlp,
812 rank,
813 alpha,
814 (intermediate_size, hidden_size),
815 "down_proj"
816 ))
817 } else {
818 None
819 };
820
821 row.push(self.layers[layer].mlp.new_added_delta(vec![
822 gate_proj_delta,
823 up_proj_delta,
824 down_proj_delta,
825 ])?);
826 }
827 }
828 }
829 }
830 for (layer, expert) in layers.into_iter().zip(experts) {
831 let mut experts_all = vec![self.layers[layer].mlp.clone()];
832 experts_all.extend(expert);
833 let (dtype, device) = self.layers[layer].mlp.dtype_device();
834 self.layers[layer].mlp = Box::new(MoeMlp::new(
835 experts_all,
836 config.clone(),
837 dtype,
838 &device,
839 layer,
840 gate_vb.as_ref(),
841 )?);
842 }
843 Ok(())
844 }
845 fn amoe_supported(&self) -> bool {
846 true
847 }
848}