use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, CausalMasker, Gemma3RotaryEmbedding, MatMul, Mlp, RmsNorm, RotaryEmbedding,
        ScaledEmbedding, Sdpa,
    },
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalCacheType, NormalLoadingMetadata,
        VisionModel,
    },
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

use super::config::Gemma3TextConfig;

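// Every `sliding_window_pattern`-th layer (1-based) uses full/global attention; all other
// layers use sliding-window (local) attention.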
macro_rules! is_sliding {
    ($layer_idx:expr, $cfg:expr) => {
        ($layer_idx + 1) % $cfg.sliding_window_pattern != 0
    };
}

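/// Multi-head self-attention with separate global and local (sliding-window) rotary
/// embeddings and per-head Q/K RmsNorm, optionally backed by PagedAttention.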
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
    rotary_emb_local: Arc<RotaryEmbedding>,
    use_sliding_window: bool,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
        rotary_emb_local: Arc<RotaryEmbedding>,
        cfg: &Gemma3TextConfig,
        layer_idx: usize,
        mapper: &dyn DeviceMapper,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("o_proj"),
        )?;
        let sliding_window = if is_sliding!(layer_idx, cfg) {
            Some(cfg.sliding_window)
        } else {
            None
        };

        let q_norm = RmsNorm::new_gemma(
            cfg.head_dim,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("q_norm"), false),
        )?;
        let k_norm = RmsNorm::new_gemma(
            cfg.head_dim,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("k_norm"), false),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb_global,
            rotary_emb_local,
            use_sliding_window: sliding_window.is_some(),
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
                softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
                sliding_window,
            },
            q_norm,
            k_norm,
        })
    }

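    /// Attention forward pass: projects Q/K/V, applies per-head RmsNorm, then the local
    /// (sliding-window) or global rotary embedding depending on the layer type, and runs
    /// either PagedAttention or the standard SDPA path.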
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        q = q.apply(&self.q_norm)?;
        k = k.apply(&self.k_norm)?;

        (q, k) = match self.use_sliding_window {
            true => self.rotary_emb_local.forward(&q, &k, seqlen_offsets)?,
            false => self.rotary_emb_global.forward(&q, &k, seqlen_offsets)?,
        };

        let mask = if self.use_sliding_window {
            sliding_attention_mask
        } else {
            attention_mask
        };

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

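/// A single Gemma 3 decoder layer: self-attention and an MLP, each wrapped in pre- and
/// post-layer RmsNorms ("sandwich" normalization).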
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    pre_feedforward_layernorm: RmsNorm,
    post_feedforward_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
        rotary_emb_local: Arc<RotaryEmbedding>,
        cfg: &Gemma3TextConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb_global,
            rotary_emb_local,
            cfg,
            layer_idx,
            mapper,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_activation,
            comm,
        )?;
        let input_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let pre_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
        )?;
        let post_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
            pre_feedforward_layernorm,
            post_feedforward_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(
                &xs,
                attention_mask,
                sliding_attention_mask,
                seqlen_offsets,
                kv_cache,
                metadata,
                flash_params,
            )?
            .apply(&self.post_attention_layernorm)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.pre_feedforward_layernorm)?)?
            .apply(&self.post_feedforward_layernorm)?;
        residual + xs
    }
}

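/// The Gemma 3 text decoder: scaled token embeddings, a stack of decoder layers with
/// alternating local/global attention, a final RmsNorm, and the LM head.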
pub struct TextModel {
    embed_tokens: ScaledEmbedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: usize,
    final_logit_softcapping: Option<f64>,
    cfg: ModelConfigMetadata,
}

impl TextModel {
    pub fn new(
        cfg: &Gemma3TextConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = ScaledEmbedding::new(
            (cfg.hidden_size as f64).sqrt(),
            embedding(
                cfg.vocab_size,
                cfg.hidden_size,
                mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
                &cfg.quantization_config,
            )?,
        );

        let mut global_ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            global_ropes.insert(
                device.location(),
                Arc::new(Gemma3RotaryEmbedding::new(
                    is_gptx,
                    vb.dtype(),
                    cfg,
                    device,
                )?),
            );
        }

        let mut local_ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            local_ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_local_base_freq as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb_global = global_ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let rotary_emb_local = local_ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb_global.clone(),
                rotary_emb_local.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;

        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
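        // Sliding-window layers get a sliding-window KV cache; full-attention layers get a
        // cache sized to the model's maximum position embeddings.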
        let cache_types = (0..cfg.num_hidden_layers)
            .map(|layer_idx| {
                is_sliding!(layer_idx, cfg)
                    .then(|| NormalCacheType::SlidingWindow {
                        window: cfg.sliding_window,
                    })
                    .unwrap_or(NormalCacheType::Normal {
                        max_seq_len: cfg.max_position_embeddings,
                    })
            })
            .collect::<Vec<_>>();
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::from_types(cache_types)),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            final_logit_softcapping: cfg.final_logit_softcapping,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.head_dim,
                v_head_dim: cfg.head_dim,
            },
            mapper,
        })
    }

    pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

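    /// Run the decoder over precomputed input embeddings: build the global causal and
    /// sliding-window masks, apply every decoder layer, then the final norm, the LM head,
    /// and (if configured) final logit softcapping.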
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        mut xs: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.sliding_window),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let sliding_attention_mask = sliding_attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                sliding_attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }

        let mut xs = MatMul.qmethod_matmul(&xs, &*self.lm_head)?;

        if let Some(final_logit_softcapping) = self.final_logit_softcapping {
            xs = (xs / final_logit_softcapping)?;
            xs = xs.tanh()?;
            xs = (xs * final_logit_softcapping)?;
        }

        extract_logits(&xs, context_lens)
    }
}

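// ISQ (in-situ quantization) support: expose the quantizable linear layers and the
// residual (non-quantized) tensors.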
impl IsqModel for TextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l
                .pp("input_layernorm")
                .add(&layer.input_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("pre_feedforward_layernorm")
                .add(&layer.pre_feedforward_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("post_feedforward_layernorm")
                .add(&layer.post_feedforward_layernorm.undo_gemma().unwrap());
        }

        uvb.to_safetensors()
    }

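    // These GGUF-style names are expected to follow the layer order returned by `get_layers`
    // above: lm_head first, then per-block attention and FFN weights.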
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        let mut names = Vec::new();
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

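// The surrounding Gemma 3 vision model is expected to call `forward_embeds` directly, so the
// token-based `VisionModel::forward` entry point is intentionally unreachable here.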
impl VisionModel for TextModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _pixel_values: Option<Tensor>,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _model_specific_args: Box<dyn std::any::Any>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        unreachable!()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
        unreachable!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
    fn has_conv2d(&self) -> bool {
        unreachable!()
    }
}

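// AnyMoE support: each selected layer's MLP can be replaced with a `MoeMlp` gating over the
// base MLP plus additional experts (fine-tuned copies or LoRA-derived deltas).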
impl AnyMoeBaseModelMixin for TextModel {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.layers[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.layers[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "gate_proj"
                            ))
                        } else {
                            None
                        };
                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "up_proj"
                            ))
                        } else {
                            None
                        };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.layers[layer].mlp.new_added_delta(vec![
                            gate_proj_delta,
                            up_proj_delta,
                            down_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}