// mistralrs_core/vision_models/gemma3/mod.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5use candle_core::{Context, DType, Device, Result, Tensor, D};
6use config::Gemma3Config;
7use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
8use mmproj::Gemma3MultiModalProjector;
9use text::TextModel;
10
11use crate::{
12    amoe::{AnyMoeBaseModelMixin, MlpLayer},
13    device_map::DeviceMapper,
14    ops::NonZeroOp,
15    paged_attention::{AttentionImplementation, ModelConfigMetadata},
16    pipeline::{
17        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
18        EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
19    },
20    utils::unvarbuilder::UnVarBuilder,
21    AnyMoeConfig, AnyMoeExpertType,
22};
23
24pub mod config;
25mod inputs_processor;
26mod mmproj;
27mod text;
28pub(crate) use inputs_processor::Gemma3Processor;
29
30use super::siglip::SiglipVisionTransformer;
31
/// Gemma 3 model: a text transformer, optionally paired with a SigLIP vision
/// tower and a multi-modal projector (both present only when the config is
/// `Gemma3Config::WithVision`).
pub struct Gemma3Model {
    // Core text transformer; always present.
    language_model: TextModel,
    // Projects vision-tower outputs into the language-model embedding space.
    // `Some` iff the config is `WithVision`.
    multi_modal_projector: Option<Gemma3MultiModalProjector>,
    // SigLIP image encoder; `Some` iff the config is `WithVision`.
    vision_tower: Option<SiglipVisionTransformer>,
    // Retained config; used at runtime to recover `image_token_index`.
    cfg: Gemma3Config,
}
38
39impl Gemma3Model {
40    pub fn new(
41        cfg: &Gemma3Config,
42        vb: ShardedVarBuilder,
43        is_gptx: bool,
44        normal_loading_metadata: NormalLoadingMetadata,
45        attention_mechanism: AttentionImplementation,
46    ) -> Result<Self> {
47        match cfg {
48            Gemma3Config::Text(text_cfg) => Ok(Self {
49                language_model: TextModel::new(
50                    text_cfg,
51                    vb,
52                    is_gptx,
53                    normal_loading_metadata,
54                    attention_mechanism,
55                )?,
56                multi_modal_projector: None,
57                vision_tower: None,
58                cfg: cfg.clone(),
59            }),
60            Gemma3Config::WithVision {
61                text_config,
62                vision_config,
63                image_token_index,
64                mm_tokens_per_image: _,
65            } => {
66                assert!(*image_token_index < text_config.vocab_size);
67                Ok(Self {
68                    multi_modal_projector: Some(Gemma3MultiModalProjector::new(
69                        cfg,
70                        vb.pp("multi_modal_projector")
71                            .set_device(normal_loading_metadata.real_device.clone()),
72                    )?),
73                    vision_tower: Some(SiglipVisionTransformer::new(
74                        vision_config,
75                        vb.pp("vision_tower")
76                            .pp("vision_model")
77                            .set_device(normal_loading_metadata.real_device.clone()),
78                    )?),
79                    language_model: TextModel::new(
80                        text_config,
81                        vb.pp("language_model"),
82                        is_gptx,
83                        normal_loading_metadata,
84                        attention_mechanism,
85                    )?,
86                    cfg: cfg.clone(),
87                })
88            }
89        }
90    }
91
92    fn forward(
93        &self,
94        input_ids: &Tensor,
95        pixel_values: Option<Tensor>,
96        seqlen_offsets: &[usize],
97        context_lens: Vec<(usize, usize)>,
98        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
99        flash_params: &FlashParams,
100    ) -> Result<Tensor> {
101        let mut input_embeds = self.language_model.embed_tokens(input_ids)?;
102        if let Some(pixel_values) = pixel_values {
103            let vision_tower = self
104                .vision_tower
105                .as_ref()
106                .context("This model does not support vision.")?;
107            let multi_modal_projector = self.multi_modal_projector.as_ref().unwrap();
108            let Gemma3Config::WithVision {
109                image_token_index, ..
110            } = &self.cfg
111            else {
112                unreachable!()
113            };
114
115            let dtype = vision_tower.dtype();
116            let vision_outputs =
117                vision_tower.forward(&pixel_values.to_dtype(dtype)?, None, None)?;
118            let image_features = multi_modal_projector.forward(&vision_outputs)?;
119
120            let special_image_mask = input_ids
121                .eq(*image_token_index as f64)?
122                .unsqueeze(D::Minus1)?
123                .broadcast_as(input_embeds.shape())?
124                .to_dtype(DType::U32)?;
125
126            let mask_flat = special_image_mask.flatten_all()?;
127            let mut x_flat = input_embeds.flatten_all()?;
128            let src_flat = image_features.flatten_all()?;
129
130            let indices = mask_flat.nonzero()?.squeeze(1)?;
131            let current_vals = x_flat.gather(&indices, 0)?;
132            let diff = (src_flat - current_vals)?;
133            x_flat = x_flat.scatter_add(&indices, &diff, 0)?;
134
135            input_embeds = x_flat.reshape(input_embeds.shape())?;
136        };
137        self.language_model.forward_embeds(
138            input_ids,
139            input_embeds,
140            seqlen_offsets,
141            context_lens,
142            metadata,
143            flash_params,
144        )
145    }
146}
147
148impl IsqModel for Gemma3Model {
149    fn get_layers(
150        &mut self,
151    ) -> (
152        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
153        &dyn DeviceMapper,
154    ) {
155        self.language_model.get_layers()
156    }
157
158    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
159        match &self.cfg {
160            Gemma3Config::Text(_) => self.language_model.residual_tensors(),
161            Gemma3Config::WithVision { .. } => {
162                let vision_tower = self.vision_tower.as_ref().unwrap();
163                let multi_modal_projector = self.multi_modal_projector.as_ref().unwrap();
164
165                let uvb = UnVarBuilder::new();
166                uvb.pp("multi_modal_projector")
167                    .extend(multi_modal_projector.residual_tensors());
168                uvb.pp("language_model")
169                    .extend(self.language_model.residual_tensors());
170                uvb.pp("vision_tower")
171                    .pp("vision_model")
172                    .extend(vision_tower.residual_tensors());
173
174                uvb.to_safetensors()
175            }
176        }
177    }
178
179    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
180        self.language_model.imatrix_names()
181    }
182}
183
/// Placeholder for Gemma 3 model-specific forward arguments (none are needed).
pub struct Gemma3SpecificArgs;
185
186impl VisionModel for Gemma3Model {
187    fn forward(
188        &self,
189        input_ids: &Tensor,
190        pixel_values: Option<Tensor>,
191        seqlen_offsets: &[usize],
192        context_lens: Vec<(usize, usize)>,
193        _position_ids: Vec<usize>,
194        _model_specific_args: Box<dyn std::any::Any>,
195        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
196        flash_params: &FlashParams,
197    ) -> candle_core::Result<Tensor> {
198        self.forward(
199            input_ids,
200            pixel_values,
201            seqlen_offsets,
202            context_lens,
203            metadata,
204            flash_params,
205        )
206    }
207    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
208        Box::new(Gemma3SpecificArgs)
209    }
210    fn cache(&self) -> &EitherCache {
211        self.language_model.cache()
212    }
213    fn cache_mut(&mut self) -> &mut EitherCache {
214        self.language_model.cache_mut()
215    }
216    fn device(&self) -> &Device {
217        self.language_model.device()
218    }
219    fn max_seq_len(&self) -> usize {
220        self.language_model.max_seq_len()
221    }
222    fn config(&self) -> &ModelConfigMetadata {
223        self.language_model.config()
224    }
225    fn has_conv2d(&self) -> bool {
226        // TODO
227        false
228    }
229}
230
231impl AnyMoeBaseModelMixin for Gemma3Model {
232    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
233        self.language_model.get_mlps()
234    }
235    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
236        self.language_model.get_mlps_mut()
237    }
238    fn create_anymoe_layers(
239        &mut self,
240        additional_vbs: Vec<ShardedVarBuilder>,
241        config: AnyMoeConfig,
242        (prefix, mlp): (String, String),
243        layers: Vec<usize>,
244        expert_type: AnyMoeExpertType,
245        gate_vb: Option<ShardedVarBuilder>,
246    ) -> Result<()> {
247        self.language_model.create_anymoe_layers(
248            additional_vbs,
249            config,
250            (prefix, mlp),
251            layers,
252            expert_type,
253            gate_vb,
254        )
255    }
256    fn amoe_supported(&self) -> bool {
257        true
258    }
259}