// mistralrs_core/vision_models/gemma3/mod.rs
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5use candle_core::{Context, DType, Device, Result, Tensor, D};
6use config::Gemma3Config;
7use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
8use mmproj::Gemma3MultiModalProjector;
9use text::TextModel;
10
11use crate::{
12 amoe::{AnyMoeBaseModelMixin, MlpLayer},
13 device_map::DeviceMapper,
14 ops::NonZeroOp,
15 paged_attention::{AttentionImplementation, ModelConfigMetadata},
16 pipeline::{
17 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
18 EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
19 },
20 utils::unvarbuilder::UnVarBuilder,
21 AnyMoeConfig, AnyMoeExpertType,
22};
23
24pub mod config;
25mod inputs_processor;
26mod mmproj;
27mod text;
28pub(crate) use inputs_processor::Gemma3Processor;
29
30use super::siglip::SiglipVisionTransformer;
31
/// Gemma 3 model: a Gemma text decoder, optionally paired with a SigLIP
/// vision tower and a multi-modal projector for image inputs.
pub struct Gemma3Model {
    // Text decoder; present for both text-only and vision configurations.
    language_model: TextModel,
    // Projects vision-tower outputs into the text embedding space.
    // `Some` only when built from `Gemma3Config::WithVision`.
    multi_modal_projector: Option<Gemma3MultiModalProjector>,
    // SigLIP vision encoder; `Some` only for `Gemma3Config::WithVision`.
    vision_tower: Option<SiglipVisionTransformer>,
    // Retained config, used at forward time to recover `image_token_index`.
    cfg: Gemma3Config,
}
38
39impl Gemma3Model {
40 pub fn new(
41 cfg: &Gemma3Config,
42 vb: ShardedVarBuilder,
43 is_gptx: bool,
44 normal_loading_metadata: NormalLoadingMetadata,
45 attention_mechanism: AttentionImplementation,
46 ) -> Result<Self> {
47 match cfg {
48 Gemma3Config::Text(text_cfg) => Ok(Self {
49 language_model: TextModel::new(
50 text_cfg,
51 vb,
52 is_gptx,
53 normal_loading_metadata,
54 attention_mechanism,
55 )?,
56 multi_modal_projector: None,
57 vision_tower: None,
58 cfg: cfg.clone(),
59 }),
60 Gemma3Config::WithVision {
61 text_config,
62 vision_config,
63 image_token_index,
64 mm_tokens_per_image: _,
65 } => {
66 assert!(*image_token_index < text_config.vocab_size);
67 Ok(Self {
68 multi_modal_projector: Some(Gemma3MultiModalProjector::new(
69 cfg,
70 vb.pp("multi_modal_projector")
71 .set_device(normal_loading_metadata.real_device.clone()),
72 )?),
73 vision_tower: Some(SiglipVisionTransformer::new(
74 vision_config,
75 vb.pp("vision_tower")
76 .pp("vision_model")
77 .set_device(normal_loading_metadata.real_device.clone()),
78 )?),
79 language_model: TextModel::new(
80 text_config,
81 vb.pp("language_model"),
82 is_gptx,
83 normal_loading_metadata,
84 attention_mechanism,
85 )?,
86 cfg: cfg.clone(),
87 })
88 }
89 }
90 }
91
92 fn forward(
93 &self,
94 input_ids: &Tensor,
95 pixel_values: Option<Tensor>,
96 seqlen_offsets: &[usize],
97 context_lens: Vec<(usize, usize)>,
98 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
99 flash_params: &FlashParams,
100 ) -> Result<Tensor> {
101 let mut input_embeds = self.language_model.embed_tokens(input_ids)?;
102 if let Some(pixel_values) = pixel_values {
103 let vision_tower = self
104 .vision_tower
105 .as_ref()
106 .context("This model does not support vision.")?;
107 let multi_modal_projector = self.multi_modal_projector.as_ref().unwrap();
108 let Gemma3Config::WithVision {
109 image_token_index, ..
110 } = &self.cfg
111 else {
112 unreachable!()
113 };
114
115 let dtype = vision_tower.dtype();
116 let vision_outputs =
117 vision_tower.forward(&pixel_values.to_dtype(dtype)?, None, None)?;
118 let image_features = multi_modal_projector.forward(&vision_outputs)?;
119
120 let special_image_mask = input_ids
121 .eq(*image_token_index as f64)?
122 .unsqueeze(D::Minus1)?
123 .broadcast_as(input_embeds.shape())?
124 .to_dtype(DType::U32)?;
125
126 let mask_flat = special_image_mask.flatten_all()?;
127 let mut x_flat = input_embeds.flatten_all()?;
128 let src_flat = image_features.flatten_all()?;
129
130 let indices = mask_flat.nonzero()?.squeeze(1)?;
131 let current_vals = x_flat.gather(&indices, 0)?;
132 let diff = (src_flat - current_vals)?;
133 x_flat = x_flat.scatter_add(&indices, &diff, 0)?;
134
135 input_embeds = x_flat.reshape(input_embeds.shape())?;
136 };
137 self.language_model.forward_embeds(
138 input_ids,
139 input_embeds,
140 seqlen_offsets,
141 context_lens,
142 metadata,
143 flash_params,
144 )
145 }
146}
147
148impl IsqModel for Gemma3Model {
149 fn get_layers(
150 &mut self,
151 ) -> (
152 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
153 &dyn DeviceMapper,
154 ) {
155 self.language_model.get_layers()
156 }
157
158 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
159 match &self.cfg {
160 Gemma3Config::Text(_) => self.language_model.residual_tensors(),
161 Gemma3Config::WithVision { .. } => {
162 let vision_tower = self.vision_tower.as_ref().unwrap();
163 let multi_modal_projector = self.multi_modal_projector.as_ref().unwrap();
164
165 let uvb = UnVarBuilder::new();
166 uvb.pp("multi_modal_projector")
167 .extend(multi_modal_projector.residual_tensors());
168 uvb.pp("language_model")
169 .extend(self.language_model.residual_tensors());
170 uvb.pp("vision_tower")
171 .pp("vision_model")
172 .extend(vision_tower.residual_tensors());
173
174 uvb.to_safetensors()
175 }
176 }
177 }
178
179 fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
180 self.language_model.imatrix_names()
181 }
182}
183
184pub struct Gemma3SpecificArgs;
185
186impl VisionModel for Gemma3Model {
187 fn forward(
188 &self,
189 input_ids: &Tensor,
190 pixel_values: Option<Tensor>,
191 seqlen_offsets: &[usize],
192 context_lens: Vec<(usize, usize)>,
193 _position_ids: Vec<usize>,
194 _model_specific_args: Box<dyn std::any::Any>,
195 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
196 flash_params: &FlashParams,
197 ) -> candle_core::Result<Tensor> {
198 self.forward(
199 input_ids,
200 pixel_values,
201 seqlen_offsets,
202 context_lens,
203 metadata,
204 flash_params,
205 )
206 }
207 fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
208 Box::new(Gemma3SpecificArgs)
209 }
210 fn cache(&self) -> &EitherCache {
211 self.language_model.cache()
212 }
213 fn cache_mut(&mut self) -> &mut EitherCache {
214 self.language_model.cache_mut()
215 }
216 fn device(&self) -> &Device {
217 self.language_model.device()
218 }
219 fn max_seq_len(&self) -> usize {
220 self.language_model.max_seq_len()
221 }
222 fn config(&self) -> &ModelConfigMetadata {
223 self.language_model.config()
224 }
225 fn has_conv2d(&self) -> bool {
226 false
228 }
229}
230
231impl AnyMoeBaseModelMixin for Gemma3Model {
232 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
233 self.language_model.get_mlps()
234 }
235 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
236 self.language_model.get_mlps_mut()
237 }
238 fn create_anymoe_layers(
239 &mut self,
240 additional_vbs: Vec<ShardedVarBuilder>,
241 config: AnyMoeConfig,
242 (prefix, mlp): (String, String),
243 layers: Vec<usize>,
244 expert_type: AnyMoeExpertType,
245 gate_vb: Option<ShardedVarBuilder>,
246 ) -> Result<()> {
247 self.language_model.create_anymoe_layers(
248 additional_vbs,
249 config,
250 (prefix, mlp),
251 layers,
252 expert_type,
253 gate_vb,
254 )
255 }
256 fn amoe_supported(&self) -> bool {
257 true
258 }
259}