mistralrs_core/vision_models/gemma3/
mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
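//! Gemma 3 multimodal model: an optional SigLIP vision tower plus a
//! multi-modal projector feeding image embeddings into the Gemma 3 text
//! decoder.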

use std::sync::Arc;

use candle_core::{Context, DType, Device, Result, Tensor, D};
use config::Gemma3Config;
use mistralrs_quant::{NonZeroOp, QuantMethod, ShardedVarBuilder};
use mmproj::Gemma3MultiModalProjector;
use text::TextModel;

use crate::{
    amoe::{AnyMoeBaseModelMixin, MlpLayer},
    device_map::DeviceMapper,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
    },
    utils::unvarbuilder::UnVarBuilder,
    AnyMoeConfig, AnyMoeExpertType,
};

pub mod config;
mod inputs_processor;
mod mmproj;
mod text;
pub(crate) use inputs_processor::Gemma3Processor;

use super::siglip::SiglipVisionTransformer;

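/// Gemma 3 model with optional vision support. For text-only configs the
/// vision tower and projector are absent; for `WithVision` configs, projected
/// image features are merged into the token embeddings before decoding.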
pub struct Gemma3Model {
    language_model: TextModel,
    multi_modal_projector: Option<Gemma3MultiModalProjector>,
    vision_tower: Option<SiglipVisionTransformer>,
    cfg: Gemma3Config,
}

impl Gemma3Model {
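    /// Construct either a text-only or a vision-capable model, depending on
    /// which `Gemma3Config` variant is supplied.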
    pub fn new(
        cfg: &Gemma3Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        match cfg {
            Gemma3Config::Text(text_cfg) => Ok(Self {
                language_model: TextModel::new(
                    text_cfg,
                    vb,
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?,
                multi_modal_projector: None,
                vision_tower: None,
                cfg: cfg.clone(),
            }),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                image_token_index,
                mm_tokens_per_image: _,
            } => {
                // The image placeholder token must be a valid id in the text vocabulary.
                assert!(*image_token_index < text_config.vocab_size);
                Ok(Self {
                    multi_modal_projector: Some(Gemma3MultiModalProjector::new(
                        cfg,
                        vb.pp("multi_modal_projector")
                            .set_device(normal_loading_metadata.real_device.clone()),
                    )?),
                    vision_tower: Some(SiglipVisionTransformer::new(
                        vision_config,
                        vb.pp("vision_tower")
                            .pp("vision_model")
                            .set_device(normal_loading_metadata.real_device.clone()),
                    )?),
                    language_model: TextModel::new(
                        text_config,
                        vb.pp("language_model"),
                        is_gptx,
                        normal_loading_metadata,
                        attention_mechanism,
                    )?,
                    cfg: cfg.clone(),
                })
            }
        }
    }

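    /// Embed `input_ids`, splice projected image features into the positions
    /// of image placeholder tokens when `pixel_values` are given, then run
    /// the text decoder on the resulting embeddings.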
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut input_embeds = self.language_model.embed_tokens(input_ids)?;
        if let Some(pixel_values) = pixel_values {
            let Gemma3Config::WithVision {
                image_token_index, ..
            } = &self.cfg
            else {
                unreachable!()
            };
            // Mask marking every position that holds the image placeholder
            // token, broadcast across the hidden dimension.
            let special_image_mask = input_ids
                .eq(*image_token_index as f64)?
                .unsqueeze(D::Minus1)?
                .broadcast_as(input_embeds.shape())?
                .to_dtype(DType::U32)?;

            let mask_flat = special_image_mask.flatten_all()?;
            // Run `nonzero` before the vision model to allow async processing
            // all the way through the logits.
            let indices = mask_flat.nonzero()?.squeeze(1)?;

            let vision_tower = self
                .vision_tower
                .as_ref()
                .context("This model does not support vision.")?;
            let multi_modal_projector = self.multi_modal_projector.as_ref().unwrap();
            let dtype = vision_tower.dtype();
            let vision_outputs =
                vision_tower.forward(&pixel_values.to_dtype(dtype)?, None, None)?;
            let image_features = multi_modal_projector.forward(&vision_outputs)?;

            // Overwrite the placeholder embeddings with the projected image
            // features: gather the current values, then scatter-add the
            // difference so the result equals `src_flat` at `indices`.
            let mut x_flat = input_embeds.flatten_all()?;
            let src_flat = image_features.flatten_all()?;

            let current_vals = x_flat.gather(&indices, 0)?;
            let diff = (src_flat - current_vals)?;
            x_flat = x_flat.scatter_add(&indices, &diff, 0)?;

            input_embeds = x_flat.reshape(input_embeds.shape())?;
        }
        self.language_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}
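
// A minimal sketch (not part of the original file) of the gather/scatter_add
// merge used in `forward` above: adding `src - current` at the target indices
// overwrites those positions with `src` in a single scatter op. The tensor
// values and shapes here are illustrative only.
#[cfg(test)]
mod scatter_merge_sketch {
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn overwrite_via_scatter_add() -> Result<()> {
        let device = Device::Cpu;
        // Flattened embeddings and the flat image features to splice in.
        let x_flat = Tensor::new(&[1f32, 2., 3., 4.], &device)?;
        let src_flat = Tensor::new(&[10f32, 20.], &device)?;
        // Positions of the image placeholder tokens.
        let indices = Tensor::new(&[1u32, 3], &device)?;

        let current = x_flat.gather(&indices, 0)?;
        let diff = (src_flat - current)?;
        let merged = x_flat.scatter_add(&indices, &diff, 0)?;

        assert_eq!(merged.to_vec1::<f32>()?, vec![1., 10., 3., 20.]);
        Ok(())
    }
}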

impl IsqModel for Gemma3Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.language_model.get_layers()
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        match &self.cfg {
            Gemma3Config::Text(_) => self.language_model.residual_tensors(),
            Gemma3Config::WithVision { .. } => {
                let vision_tower = self.vision_tower.as_ref().unwrap();
                let multi_modal_projector = self.multi_modal_projector.as_ref().unwrap();

                // Collect the non-quantized tensors under their original
                // checkpoint prefixes.
                let uvb = UnVarBuilder::new();
                uvb.pp("multi_modal_projector")
                    .extend(multi_modal_projector.residual_tensors());
                uvb.pp("language_model")
                    .extend(self.language_model.residual_tensors());
                uvb.pp("vision_tower")
                    .pp("vision_model")
                    .extend(vision_tower.residual_tensors());

                uvb.to_safetensors()
            }
        }
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        self.language_model.imatrix_names()
    }
}
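/// Placeholder for model-specific arguments; the Gemma 3 `forward` ignores them.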
pub struct Gemma3SpecificArgs;

impl VisionModel for Gemma3Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _model_specific_args: Box<dyn std::any::Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        self.forward(
            input_ids,
            pixel_values,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
        Box::new(Gemma3SpecificArgs)
    }
    fn cache(&self) -> &EitherCache {
        self.language_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.language_model.cache_mut()
    }
    fn device(&self) -> &Device {
        self.language_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_seq_len()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.language_model.config()
    }
    fn has_conv2d(&self) -> bool {
        // TODO
        false
    }
}

impl AnyMoeBaseModelMixin for Gemma3Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        self.language_model.get_mlps()
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        self.language_model.get_mlps_mut()
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        self.language_model.create_anymoe_layers(
            additional_vbs,
            config,
            (prefix, mlp),
            layers,
            expert_type,
            gate_vb,
        )
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}