mistralrs_core/vision_models/gemma3/mmproj.rs

1use candle_core::{Result, Tensor};
2use candle_nn::Module;
3use mistralrs_quant::ShardedVarBuilder;
4
5use crate::{
6    layers::{AvgPool2d, RmsNorm},
7    utils::unvarbuilder::UnVarBuilder,
8};
9
10use super::config::Gemma3Config;
11
12pub struct Gemma3MultiModalProjector {
13    mm_input_projection_weight: Tensor,
14    mm_soft_emb_norm: RmsNorm,
15    patches_per_image: usize,
16    avg_pool: AvgPool2d,
17}
18
19impl Gemma3MultiModalProjector {
20    pub fn new(cfg: &Gemma3Config, vb: ShardedVarBuilder) -> Result<Self> {
21        let Gemma3Config::WithVision {
22            text_config,
23            vision_config,
24            image_token_index: _,
25            mm_tokens_per_image,
26        } = cfg
27        else {
28            unreachable!()
29        };
30
31        let mm_input_projection_weight = vb.get(
32            (vision_config.hidden_size, text_config.hidden_size),
33            "mm_input_projection_weight",
34        )?;
35        let mm_soft_emb_norm = RmsNorm::new_gemma(
36            vision_config.hidden_size,
37            vision_config.layer_norm_eps,
38            vb.pp("mm_soft_emb_norm"),
39        )?;
40
41        let patches_per_image = vision_config.image_size / vision_config.patch_size;
42        let tokens_per_side = mm_tokens_per_image.isqrt();
43        let kernel_size = patches_per_image / tokens_per_side;
44        let avg_pool = AvgPool2d::new(kernel_size, kernel_size);
45
46        Ok(Self {
47            mm_input_projection_weight,
48            mm_soft_emb_norm,
49            patches_per_image,
50            avg_pool,
51        })
52    }
53
54    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
55        let (bs, _, seqlen) = xs.dims3()?;
56
57        let mut reshaped_vision_outputs = xs.transpose(1, 2)?;
58        reshaped_vision_outputs = reshaped_vision_outputs.reshape((
59            bs,
60            seqlen,
61            self.patches_per_image,
62            self.patches_per_image,
63        ))?;
64        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()?;
65
66        let mut pooled_vision_outputs = self.avg_pool.forward(&reshaped_vision_outputs)?;
67        pooled_vision_outputs = pooled_vision_outputs.flatten_from(2)?;
68        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)?;
69
70        let normed_vision_outputs = self.mm_soft_emb_norm.forward(&pooled_vision_outputs)?;
71
72        normed_vision_outputs.broadcast_matmul(&self.mm_input_projection_weight)
73    }
74
75    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
76        let uvb = UnVarBuilder::new();
77
78        uvb.add_tensor(
79            "mm_input_projection_weight",
80            self.mm_input_projection_weight.clone(),
81        );
82        uvb.pp("mm_soft_emb_norm")
83            .add(&self.mm_soft_emb_norm.undo_gemma().unwrap());
84
85        uvb.to_safetensors()
86    }
87}