// mistralrs_core/vision_models/gemma3/mod.rs
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

3use std::sync::Arc;
4
5use candle_core::{Context, DType, Device, Result, Tensor, D};
6use config::Gemma3Config;
7use mistralrs_quant::{NonZeroOp, QuantMethod, ShardedVarBuilder};
8use mmproj::Gemma3MultiModalProjector;
9use text::TextModel;
10
11use crate::{
12 amoe::{AnyMoeBaseModelMixin, MlpLayer},
13 device_map::DeviceMapper,
14 paged_attention::{AttentionImplementation, ModelConfigMetadata},
15 pipeline::{
16 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
17 EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
18 },
19 utils::unvarbuilder::UnVarBuilder,
20 AnyMoeConfig, AnyMoeExpertType,
21};
22
23pub mod config;
24mod inputs_processor;
25mod mmproj;
26mod text;
27pub(crate) use inputs_processor::Gemma3Processor;
28
29use super::siglip::SiglipVisionTransformer;
30
31pub struct Gemma3Model {
32 language_model: TextModel,
33 multi_modal_projector: Option<Gemma3MultiModalProjector>,
34 vision_tower: Option<SiglipVisionTransformer>,
35 cfg: Gemma3Config,
36}
37
impl Gemma3Model {
    /// Builds a Gemma 3 model from `cfg`.
    ///
    /// For `Gemma3Config::Text` only the language model is constructed.
    /// For `Gemma3Config::WithVision` the SigLIP vision tower and the
    /// multi-modal projector are also loaded (both pinned to the real
    /// device), and the text weights are read under the `language_model`
    /// prefix.
    pub fn new(
        cfg: &Gemma3Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        match cfg {
            Gemma3Config::Text(text_cfg) => Ok(Self {
                language_model: TextModel::new(
                    text_cfg,
                    vb,
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?,
                multi_modal_projector: None,
                vision_tower: None,
                cfg: cfg.clone(),
            }),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                image_token_index,
                mm_tokens_per_image: _,
            } => {
                // The image placeholder must be a valid token id in the text vocab.
                assert!(*image_token_index < text_config.vocab_size);
                Ok(Self {
                    multi_modal_projector: Some(Gemma3MultiModalProjector::new(
                        cfg,
                        vb.pp("multi_modal_projector")
                            .set_device(normal_loading_metadata.real_device.clone()),
                    )?),
                    vision_tower: Some(SiglipVisionTransformer::new(
                        vision_config,
                        vb.pp("vision_tower")
                            .pp("vision_model")
                            .set_device(normal_loading_metadata.real_device.clone()),
                    )?),
                    language_model: TextModel::new(
                        text_config,
                        vb.pp("language_model"),
                        is_gptx,
                        normal_loading_metadata,
                        attention_mechanism,
                    )?,
                    cfg: cfg.clone(),
                })
            }
        }
    }

    /// Forward pass: embeds `input_ids`; when `pixel_values` are present,
    /// replaces the embeddings at image-placeholder token positions with the
    /// projected vision features, then runs the text model on the embeddings.
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut input_embeds = self.language_model.embed_tokens(input_ids)?;
        if let Some(pixel_values) = pixel_values {
            // Pixel values are only ever produced for a vision config;
            // reaching this branch with a text-only config is a pipeline bug.
            let Gemma3Config::WithVision {
                image_token_index, ..
            } = &self.cfg
            else {
                unreachable!()
            };
            // 0/1 mask marking every embedding element that belongs to an
            // image placeholder token, broadcast over the hidden dimension.
            let special_image_mask = input_ids
                .eq(*image_token_index as f64)?
                .unsqueeze(D::Minus1)?
                .broadcast_as(input_embeds.shape())?
                .to_dtype(DType::U32)?;

            let mask_flat = special_image_mask.flatten_all()?;
            // Flat element indices of all embedding slots to overwrite.
            let indices = mask_flat.nonzero()?.squeeze(1)?;

            let vision_tower = self
                .vision_tower
                .as_ref()
                .context("This model does not support vision.")?;
            let multi_modal_projector = self.multi_modal_projector.as_ref().unwrap();
            let dtype = vision_tower.dtype();
            let vision_outputs =
                vision_tower.forward(&pixel_values.to_dtype(dtype)?, None, None)?;
            let image_features = multi_modal_projector.forward(&vision_outputs)?;

            // NOTE(review): assumes image_features flattens to exactly as many
            // elements as there are masked slots — guaranteed by the inputs
            // processor emitting one placeholder per image token; verify there.
            let mut x_flat = input_embeds.flatten_all()?;
            let src_flat = image_features.flatten_all()?;

            // Overwrite-via-scatter_add: adding (src - current) at the masked
            // indices leaves those slots equal to the image features.
            let current_vals = x_flat.gather(&indices, 0)?;
            let diff = (src_flat - current_vals)?;
            x_flat = x_flat.scatter_add(&indices, &diff, 0)?;

            input_embeds = x_flat.reshape(input_embeds.shape())?;
        };
        self.language_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}
147
148impl IsqModel for Gemma3Model {
149 fn get_layers(
150 &mut self,
151 ) -> (
152 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
153 &dyn DeviceMapper,
154 ) {
155 self.language_model.get_layers()
156 }
157
158 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
159 match &self.cfg {
160 Gemma3Config::Text(_) => self.language_model.residual_tensors(),
161 Gemma3Config::WithVision { .. } => {
162 let vision_tower = self.vision_tower.as_ref().unwrap();
163 let multi_modal_projector = self.multi_modal_projector.as_ref().unwrap();
164
165 let uvb = UnVarBuilder::new();
166 uvb.pp("multi_modal_projector")
167 .extend(multi_modal_projector.residual_tensors());
168 uvb.pp("language_model")
169 .extend(self.language_model.residual_tensors());
170 uvb.pp("vision_tower")
171 .pp("vision_model")
172 .extend(vision_tower.residual_tensors());
173
174 uvb.to_safetensors()
175 }
176 }
177 }
178
179 fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
180 self.language_model.imatrix_names()
181 }
182}
183
184pub struct Gemma3SpecificArgs;
185
186impl VisionModel for Gemma3Model {
187 fn forward(
188 &self,
189 input_ids: &Tensor,
190 pixel_values: Option<Tensor>,
191 seqlen_offsets: &[usize],
192 context_lens: Vec<(usize, usize)>,
193 _position_ids: Vec<usize>,
194 _model_specific_args: Box<dyn std::any::Any>,
195 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
196 flash_params: &FlashParams,
197 ) -> candle_core::Result<Tensor> {
198 self.forward(
199 input_ids,
200 pixel_values,
201 seqlen_offsets,
202 context_lens,
203 metadata,
204 flash_params,
205 )
206 }
207 fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
208 Box::new(Gemma3SpecificArgs)
209 }
210 fn cache(&self) -> &EitherCache {
211 self.language_model.cache()
212 }
213 fn cache_mut(&mut self) -> &mut EitherCache {
214 self.language_model.cache_mut()
215 }
216 fn device(&self) -> &Device {
217 self.language_model.device()
218 }
219 fn max_seq_len(&self) -> usize {
220 self.language_model.max_seq_len()
221 }
222 fn config(&self) -> &ModelConfigMetadata {
223 self.language_model.config()
224 }
225 fn has_conv2d(&self) -> bool {
226 false
228 }
229}
230
231impl AnyMoeBaseModelMixin for Gemma3Model {
232 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
233 self.language_model.get_mlps()
234 }
235 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
236 self.language_model.get_mlps_mut()
237 }
238 fn create_anymoe_layers(
239 &mut self,
240 additional_vbs: Vec<ShardedVarBuilder>,
241 config: AnyMoeConfig,
242 (prefix, mlp): (String, String),
243 layers: Vec<usize>,
244 expert_type: AnyMoeExpertType,
245 gate_vb: Option<ShardedVarBuilder>,
246 ) -> Result<()> {
247 self.language_model.create_anymoe_layers(
248 additional_vbs,
249 config,
250 (prefix, mlp),
251 layers,
252 expert_type,
253 gate_vb,
254 )
255 }
256 fn amoe_supported(&self) -> bool {
257 true
258 }
259}