mistralrs_core/vision_models/gemma3/text.rs

use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, CausalMasker, Gemma3RotaryEmbedding, MatMul, Mlp, RmsNorm, RotaryEmbedding,
        ScaledEmbedding, Sdpa,
    },
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalCacheType, NormalLoadingMetadata,
        VisionModel,
    },
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

use super::config::Gemma3TextConfig;

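// Gemma 3 interleaves attention types: every `sliding_window_pattern`-th layer (1-indexed)
// uses full (global) attention, and all other layers use sliding-window attention.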
macro_rules! is_sliding {
    ($layer_idx:expr, $cfg:expr) => {
        ($layer_idx + 1) % $cfg.sliding_window_pattern != 0
    };
}

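// Self-attention with per-head RMSNorm applied to queries and keys. Each layer holds both the
// global RoPE (used by full-attention layers) and the local RoPE (used by sliding-window
// layers) and selects one at forward time based on `use_sliding_window`.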
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
    rotary_emb_local: Arc<RotaryEmbedding>,
    use_sliding_window: bool,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
        rotary_emb_local: Arc<RotaryEmbedding>,
        cfg: &Gemma3TextConfig,
        layer_idx: usize,
        mapper: &dyn DeviceMapper,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("o_proj"),
        )?;
        let sliding_window = if is_sliding!(layer_idx, cfg) {
            Some(cfg.sliding_window)
        } else {
            None
        };

        let q_norm = RmsNorm::new_gemma(
            cfg.head_dim,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("q_norm"), false),
        )?;
        let k_norm = RmsNorm::new_gemma(
            cfg.head_dim,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("k_norm"), false),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb_global,
            rotary_emb_local,
            use_sliding_window: sliding_window.is_some(),
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
                softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
                sliding_window,
            },
            q_norm,
            k_norm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        q = q.apply(&self.q_norm)?;
        k = k.apply(&self.k_norm)?;

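        // Sliding-window layers rotate with the local RoPE (built from `rope_local_base_freq`);
        // full-attention layers use the global RoPE.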
        (q, k) = match self.use_sliding_window {
            true => self.rotary_emb_local.forward(&q, &k, seqlen_offsets)?,
            false => self.rotary_emb_global.forward(&q, &k, seqlen_offsets)?,
        };

        let mask = if self.use_sliding_window {
            sliding_attention_mask
        } else {
            attention_mask
        };

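        // With PagedAttention, the paged KV cache and input metadata arrive via `metadata`;
        // otherwise keys/values are appended to this layer's KvCache and plain SDPA is used.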
        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If there is no metadata, we are most likely collecting an imatrix, so we don't want to populate the paged KV cache.
                    // Generate dummy metadata under the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                // sdpa_params.sliding_window is None when !self.use_sliding_window
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

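// Decoder layer with Gemma-style "sandwich" normalization: RMSNorm before and after the
// attention block, and again before and after the MLP, with the residual added after each
// post-norm.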
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    pre_feedforward_layernorm: RmsNorm,
    post_feedforward_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
        rotary_emb_local: Arc<RotaryEmbedding>,
        cfg: &Gemma3TextConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb_global,
            rotary_emb_local,
            cfg,
            layer_idx,
            mapper,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_activation,
            comm,
        )?;
        let input_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let pre_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
        )?;
        let post_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
            pre_feedforward_layernorm,
            post_feedforward_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(
                &xs,
                attention_mask,
                sliding_attention_mask,
                seqlen_offsets,
                kv_cache,
                metadata,
                flash_params,
            )?
            .apply(&self.post_attention_layernorm)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.pre_feedforward_layernorm)?)?
            .apply(&self.post_feedforward_layernorm)?;
        residual + xs
    }
}

pub struct TextModel {
    embed_tokens: ScaledEmbedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: usize,
    final_logit_softcapping: Option<f64>,
    cfg: ModelConfigMetadata,
}

impl TextModel {
    pub fn new(
        cfg: &Gemma3TextConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
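        // Gemma scales token embeddings by sqrt(hidden_size).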
        let embed_tokens = ScaledEmbedding::new(
            (cfg.hidden_size as f64).sqrt(),
            embedding(
                cfg.vocab_size,
                cfg.hidden_size,
                mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
                &cfg.quantization_config,
            )?,
        );

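        // Build one RoPE table per device location: the global table below (full-attention
        // layers) and, in the next loop, a local table using `rope_local_base_freq`
        // (sliding-window layers).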
        let mut global_ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            global_ropes.insert(
                device.location(),
                Arc::new(Gemma3RotaryEmbedding::new(
                    is_gptx,
                    vb.dtype(),
                    cfg,
                    device,
                )?),
            );
        }

        let mut local_ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            local_ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_local_base_freq as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb_global = global_ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let rotary_emb_local = local_ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb_global.clone(),
                rotary_emb_local.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;

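        // With tied word embeddings, the LM head reuses the embedding matrix; otherwise a
        // separate `lm_head` projection is loaded.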
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
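        // Sliding-window layers get a sliding-window KV cache; full-attention layers get a
        // normal cache sized to `max_position_embeddings`.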
        let cache_types = (0..cfg.num_hidden_layers)
            .map(|layer_idx| {
                is_sliding!(layer_idx, cfg)
                    .then(|| NormalCacheType::SlidingWindow {
                        window: cfg.sliding_window,
                    })
                    .unwrap_or(NormalCacheType::Normal {
                        max_seq_len: cfg.max_position_embeddings,
                    })
            })
            .collect::<Vec<_>>();
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::from_types(cache_types)),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            final_logit_softcapping: cfg.final_logit_softcapping,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.head_dim,
                v_head_dim: cfg.head_dim,
            },
            mapper,
        })
    }

    pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        mut xs: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let cache = &mut self.cache.normal().0;
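        // Two causal masks are built: a full causal mask and a sliding-window variant; each
        // layer picks the one matching its attention type.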
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.sliding_window),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let sliding_attention_mask = sliding_attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                sliding_attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }

        let mut xs = MatMul.qmethod_matmul(&xs, &*self.lm_head)?;

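        // Final logit soft-capping: logits = cap * tanh(logits / cap).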
        if let Some(final_logit_softcapping) = self.final_logit_softcapping {
            xs = (xs / final_logit_softcapping)?;
            xs = xs.tanh()?;
            xs = (xs * final_logit_softcapping)?;
        }

        extract_logits(&xs, context_lens)
    }
}

impl IsqModel for TextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l
                .pp("input_layernorm")
                .add(&layer.input_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("pre_feedforward_layernorm")
                .add(&layer.pre_feedforward_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("post_feedforward_layernorm")
                .add(&layer.post_feedforward_layernorm.undo_gemma().unwrap());
        }

        uvb.to_safetensors()
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // NOTE: dependent on the exact implementation in get_layers!
        let mut names = Vec::new();
        // lm_head
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

impl VisionModel for TextModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _pixel_values: Option<Tensor>,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _model_specific_args: Box<dyn std::any::Any>, // pixel attention mask, or image sizes, or anything else
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        unreachable!()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
        unreachable!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
    fn has_conv2d(&self) -> bool {
        unreachable!()
    }
}

impl AnyMoeBaseModelMixin for TextModel {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.layers[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.layers[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "gate_proj"
                            ))
                        } else {
                            None
                        };
                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "up_proj"
                            ))
                        } else {
                            None
                        };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.layers[layer].mlp.new_added_delta(vec![
                            gate_proj_delta,
                            up_proj_delta,
                            down_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}