mistralrs_core/vision_models/gemma3/
mmproj.rs
1use candle_core::{Result, Tensor};
2use candle_nn::Module;
3use mistralrs_quant::ShardedVarBuilder;
4
5use crate::{
6 layers::{AvgPool2d, RmsNorm},
7 utils::unvarbuilder::UnVarBuilder,
8};
9
10use super::config::Gemma3Config;
11
12pub struct Gemma3MultiModalProjector {
13 mm_input_projection_weight: Tensor,
14 mm_soft_emb_norm: RmsNorm,
15 patches_per_image: usize,
16 avg_pool: AvgPool2d,
17}
18
19impl Gemma3MultiModalProjector {
20 pub fn new(cfg: &Gemma3Config, vb: ShardedVarBuilder) -> Result<Self> {
21 let Gemma3Config::WithVision {
22 text_config,
23 vision_config,
24 image_token_index: _,
25 mm_tokens_per_image,
26 } = cfg
27 else {
28 unreachable!()
29 };
30
31 let mm_input_projection_weight = vb.get(
32 (vision_config.hidden_size, text_config.hidden_size),
33 "mm_input_projection_weight",
34 )?;
35 let mm_soft_emb_norm = RmsNorm::new_gemma(
36 vision_config.hidden_size,
37 vision_config.layer_norm_eps,
38 vb.pp("mm_soft_emb_norm"),
39 )?;
40
41 let patches_per_image = vision_config.image_size / vision_config.patch_size;
42 let tokens_per_side = mm_tokens_per_image.isqrt();
43 let kernel_size = patches_per_image / tokens_per_side;
44 let avg_pool = AvgPool2d::new(kernel_size, kernel_size);
45
46 Ok(Self {
47 mm_input_projection_weight,
48 mm_soft_emb_norm,
49 patches_per_image,
50 avg_pool,
51 })
52 }
53
54 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
55 let (bs, _, seqlen) = xs.dims3()?;
56
57 let mut reshaped_vision_outputs = xs.transpose(1, 2)?;
58 reshaped_vision_outputs = reshaped_vision_outputs.reshape((
59 bs,
60 seqlen,
61 self.patches_per_image,
62 self.patches_per_image,
63 ))?;
64 reshaped_vision_outputs = reshaped_vision_outputs.contiguous()?;
65
66 let mut pooled_vision_outputs = self.avg_pool.forward(&reshaped_vision_outputs)?;
67 pooled_vision_outputs = pooled_vision_outputs.flatten_from(2)?;
68 pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)?;
69
70 let normed_vision_outputs = self.mm_soft_emb_norm.forward(&pooled_vision_outputs)?;
71
72 normed_vision_outputs.broadcast_matmul(&self.mm_input_projection_weight)
73 }
74
75 pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
76 let uvb = UnVarBuilder::new();
77
78 uvb.add_tensor(
79 "mm_input_projection_weight",
80 self.mm_input_projection_weight.clone(),
81 );
82 uvb.pp("mm_soft_emb_norm")
83 .add(&self.mm_soft_emb_norm.undo_gemma().unwrap());
84
85 uvb.to_safetensors()
86 }
87}