// mistralrs_core/vision_models/mistral3/mod.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5use crate::{
6    amoe::{AnyMoeBaseModelMixin, MlpLayer},
7    device_map::DeviceMapper,
8    layers::{self, Activation, RmsNorm},
9    models,
10    ops::SplitOp,
11    paged_attention::{AttentionImplementation, ModelConfigMetadata},
12    pipeline::{
13        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
14        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
15    },
16    utils::unvarbuilder::UnVarBuilder,
17    AnyMoeConfig, AnyMoeExpertType,
18};
19use candle_core::{DType, Device, Result, Tensor, D};
20use candle_nn::{Linear, Module};
21pub use config::Mistral3Config;
22pub use inputs_processor::Mistral3Processor;
23use mistralrs_quant::{NonZeroOp, QuantMethod, ShardedVarBuilder};
24use models::mistral::Model as Mistral;
25use vision::Mistral3VisionModel;
26
27mod config;
28mod inputs_processor;
29mod vision;
30
/// Merges spatially adjacent vision patches into a single token.
///
/// Takes the flat per-patch features from the vision tower, groups each
/// `spatial_merge_size x spatial_merge_size` window of patches, and projects the
/// concatenated window back down to the vision hidden size via `merging_layer`.
struct Mistral3PatchMerger {
    // Linear of shape (hidden * spatial_merge_size^2) -> hidden; no bias.
    merging_layer: Linear,
    // Side length (in patches) of the square window that is merged into one token.
    spatial_merge_size: usize,
    // Vision-tower patch size in pixels; used to convert image sizes to patch grids.
    patch_size: usize,
}
36
37impl Mistral3PatchMerger {
38    fn new(cfg: &Mistral3Config, vb: ShardedVarBuilder) -> Result<Self> {
39        Ok(Self {
40            merging_layer: layers::linear_no_bias(
41                cfg.vision_config.hidden_size * cfg.spatial_merge_size.pow(2),
42                cfg.vision_config.hidden_size,
43                vb.pp("merging_layer"),
44            )?,
45            spatial_merge_size: cfg.spatial_merge_size,
46            patch_size: cfg.vision_config.patch_size,
47        })
48    }
49
50    fn forward(&self, image_features: &Tensor, image_sizes: Vec<(u32, u32)>) -> Result<Tensor> {
51        let image_sizes = image_sizes
52            .iter()
53            .map(|&(h, w)| (h as usize / self.patch_size, w as usize / self.patch_size))
54            .collect::<Vec<_>>();
55
56        let tokens_per_image = image_sizes.iter().map(|&(h, w)| h * w).collect::<Vec<_>>();
57        let d = image_features.dim(D::Minus1)?;
58
59        let mut permuted_tensor = Vec::new();
60
61        for (image_index, image_tokens) in image_features
62            .split(&tokens_per_image, 0)?
63            .iter()
64            .enumerate()
65        {
66            let (h, w) = image_sizes[image_index];
67            let image_grid = image_tokens
68                .reshape((h, w, d))?
69                .permute((2, 0, 1))?
70                .unsqueeze(0)?;
71            // Equiv of:
72            // torch.nn.functional.unfold(image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size)
73            let grid = {
74                // The first unfold extracts sliding windows along the height (dim=2),
75                // and the second unfolds the width (dim=3).
76                let patches = image_grid
77                    .unfold(2, self.spatial_merge_size, self.spatial_merge_size)?
78                    .unfold(3, self.spatial_merge_size, self.spatial_merge_size)?;
79                // patches now has shape: (N, C, n_H, n_W, K, K) where n_H = (H - K) // K + 1 and n_W = (W - K) // K + 1
80
81                let patches = patches.permute((0, 1, 4, 5, 2, 3))?;
82                patches.contiguous()?.reshape((
83                    1,
84                    d * self.spatial_merge_size * self.spatial_merge_size,
85                    (),
86                ))?
87            };
88            let grid = grid
89                .reshape((d * self.spatial_merge_size.pow(2), ()))?
90                .t()?;
91            permuted_tensor.push(grid);
92        }
93
94        let image_features = Tensor::cat(&permuted_tensor, 0)?;
95
96        self.merging_layer.forward(&image_features)
97    }
98}
99
/// Projects vision-tower features into the text model's embedding space.
///
/// Pipeline: RMS-norm -> patch merging -> linear_1 -> activation -> linear_2.
struct Mistral3MultiModalProjector {
    // RMS normalization over the vision hidden size, using the text model's eps.
    norm: RmsNorm,
    // First projection: vision hidden size -> text hidden size.
    linear_1: Linear,
    // Second projection: text hidden size -> text hidden size.
    linear_2: Linear,
    // Activation applied between the two linears (cfg.projector_hidden_act).
    act: Activation,
    // Merges spatial patch windows before projection.
    patch_merger: Mistral3PatchMerger,
}
107
108impl Mistral3MultiModalProjector {
109    fn new(cfg: &Mistral3Config, vb: ShardedVarBuilder) -> Result<Self> {
110        let norm = RmsNorm::new(
111            cfg.vision_config.hidden_size,
112            cfg.text_config.rms_norm_eps,
113            vb.pp("norm"),
114        )?;
115        // let num_feature_layers = match &cfg.vision_feature_layer {
116        //     Either::Left(_) => 1,
117        //     Either::Right(r) => r.len(),
118        // };
119        let num_feature_layers = 1;
120        let linear_1 = layers::linear_b(
121            cfg.vision_config.hidden_size * num_feature_layers,
122            cfg.text_config.hidden_size,
123            cfg.multimodal_projector_bias,
124            vb.pp("linear_1"),
125        )?;
126        let linear_2 = layers::linear_b(
127            cfg.text_config.hidden_size,
128            cfg.text_config.hidden_size,
129            cfg.multimodal_projector_bias,
130            vb.pp("linear_2"),
131        )?;
132        let patch_merger = Mistral3PatchMerger::new(cfg, vb.pp("patch_merger"))?;
133        Ok(Self {
134            norm,
135            linear_1,
136            linear_2,
137            act: cfg.projector_hidden_act,
138            patch_merger,
139        })
140    }
141
142    fn forward(&self, image_features: &Tensor, image_sizes: Vec<(u32, u32)>) -> Result<Tensor> {
143        let mut hidden_states = self.norm.forward(image_features)?;
144        hidden_states = self.patch_merger.forward(&hidden_states, image_sizes)?;
145        hidden_states = self.linear_1.forward(&hidden_states)?.apply(&self.act)?;
146        self.linear_2.forward(&hidden_states)
147    }
148
149    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
150        let uvb = UnVarBuilder::new();
151
152        uvb.pp("norm").add(&self.norm);
153        uvb.pp("linear_1").add(&self.linear_1);
154        uvb.pp("linear_2").add(&self.linear_2);
155        uvb.pp("patch_merger")
156            .pp("merging_layer")
157            .add(&self.patch_merger.merging_layer);
158
159        uvb.to_safetensors()
160    }
161}
162
/// Mistral 3 vision-language model: a Mistral text backbone plus a vision tower
/// whose features are projected into the text embedding space and spliced into
/// the input embeddings at image-token positions.
pub struct Mistral3Model {
    // Mistral language model (`language_model` weights).
    text_model: Mistral,
    // Vision encoder (`vision_tower` weights).
    vision_model: Mistral3VisionModel,
    // Projects vision features into text embedding space (`multi_modal_projector`).
    mmproj: Mistral3MultiModalProjector,
    // Retained config; used at runtime for e.g. `image_token_index`.
    cfg: Mistral3Config,
}
169
170impl Mistral3Model {
171    pub fn new(
172        cfg: &Mistral3Config,
173        vb: ShardedVarBuilder,
174        is_gptx: bool,
175        normal_loading_metadata: NormalLoadingMetadata,
176        attention_mechanism: AttentionImplementation,
177    ) -> Result<Self> {
178        let vision_model = Mistral3VisionModel::new(
179            &cfg.vision_config,
180            vb.pp("vision_tower"),
181            &normal_loading_metadata,
182        )?;
183        let mmproj = Mistral3MultiModalProjector::new(
184            cfg,
185            vb.pp("multi_modal_projector")
186                .set_device(normal_loading_metadata.real_device.clone()),
187        )?;
188        let text_model = Mistral::new(
189            &cfg.text_config,
190            vb.pp("language_model"),
191            is_gptx,
192            normal_loading_metadata,
193            attention_mechanism,
194        )?;
195
196        // For get_image_features, assuming this for best efficiency.
197        assert_eq!(cfg.vision_feature_layer, -1);
198
199        Ok(Self {
200            vision_model,
201            text_model,
202            mmproj,
203            cfg: cfg.clone(),
204        })
205    }
206
207    fn get_image_features(
208        &self,
209        image_features: &Tensor,
210        image_sizes: Vec<(u32, u32)>,
211    ) -> Result<Tensor> {
212        let image_outputs = self
213            .vision_model
214            .forward(image_features, image_sizes.clone())?;
215        let selected_image_feature = image_outputs;
216        self.mmproj
217            .forward(&selected_image_feature.squeeze(0)?, image_sizes)
218    }
219
220    #[allow(clippy::too_many_arguments)]
221    pub fn forward(
222        &self,
223        input_ids: &Tensor,
224        pixel_values: Option<Tensor>,
225        seqlen_offsets: &[usize],
226        context_lens: Vec<(usize, usize)>,
227        image_sizes: Option<Vec<(u32, u32)>>,
228        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
229        flash_params: &FlashParams,
230    ) -> Result<Tensor> {
231        let mut input_embeds = self.text_model.get_input_embeddings(input_ids)?;
232
233        if let Some(pixel_values) = pixel_values {
234            let special_image_mask = input_ids
235                .eq(self.cfg.image_token_index as f64)?
236                .unsqueeze(D::Minus1)?
237                .broadcast_as(input_embeds.shape().clone())?
238                .to_dtype(DType::U32)?;
239            let mask_flat = special_image_mask.flatten_all()?;
240            // Nonzero before vision model to allow async processing all the way through logits.
241            let indices = mask_flat.nonzero()?.squeeze(1)?;
242
243            let image_sizes = image_sizes.unwrap();
244            let image_features = self.get_image_features(
245                &pixel_values.to_dtype(self.vision_model.dtype())?,
246                image_sizes,
247            )?;
248
249            let mut x_flat = input_embeds.flatten_all()?;
250            let src_flat = image_features.flatten_all()?;
251
252            let current_vals = x_flat.gather(&indices, 0)?;
253            let diff = (src_flat - current_vals)?;
254            x_flat = x_flat.scatter_add(&indices, &diff, 0)?;
255
256            input_embeds = x_flat.reshape(input_embeds.shape())?;
257        }
258
259        self.text_model.forward_embeds(
260            input_ids,
261            input_embeds,
262            seqlen_offsets,
263            context_lens,
264            metadata,
265            flash_params,
266        )
267    }
268}
269
270impl IsqModel for Mistral3Model {
271    fn get_layers(
272        &mut self,
273    ) -> (
274        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
275        &dyn DeviceMapper,
276    ) {
277        let (mut tensors, mapper) = self.text_model.get_layers();
278        tensors.extend(self.vision_model.get_layers());
279        (tensors, mapper)
280    }
281
282    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
283        let uvb = UnVarBuilder::new();
284        uvb.pp("multi_modal_projector")
285            .extend(self.mmproj.residual_tensors());
286        uvb.pp("language_model")
287            .extend(self.text_model.residual_tensors());
288
289        uvb.to_safetensors()
290    }
291
292    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
293        self.text_model.imatrix_names()
294    }
295}
296
/// Model-specific arguments passed through the generic `VisionModel::forward`
/// interface; defaults to no image sizes (text-only step).
#[derive(Default)]
pub struct Mistral3SpecificArgs {
    // Per-image (height, width) in pixels; required whenever pixel values are given.
    pub image_sizes: Option<Vec<(u32, u32)>>,
}
301
302impl VisionModel for Mistral3Model {
303    fn forward(
304        &self,
305        input_ids: &Tensor,
306        pixel_values: Option<Tensor>,
307        seqlen_offsets: &[usize],
308        context_lens: Vec<(usize, usize)>,
309        _position_ids: Vec<usize>,
310        model_specific_args: Box<dyn std::any::Any>,
311        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
312        flash_params: &FlashParams,
313    ) -> candle_core::Result<Tensor> {
314        let Mistral3SpecificArgs { image_sizes } = *model_specific_args
315            .downcast()
316            .expect("Cannot downcast into `Mistral3SpecificArgs`");
317        self.forward(
318            input_ids,
319            pixel_values,
320            seqlen_offsets,
321            context_lens,
322            image_sizes,
323            metadata,
324            flash_params,
325        )
326    }
327    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
328        Box::new(Mistral3SpecificArgs::default())
329    }
330    fn cache(&self) -> &EitherCache {
331        self.text_model.cache()
332    }
333    fn cache_mut(&mut self) -> &mut EitherCache {
334        self.text_model.cache_mut()
335    }
336    fn device(&self) -> &Device {
337        self.text_model.device()
338    }
339    fn max_seq_len(&self) -> usize {
340        self.text_model.max_seq_len()
341    }
342    fn config(&self) -> &ModelConfigMetadata {
343        self.text_model.config()
344    }
345    fn has_conv2d(&self) -> bool {
346        true
347    }
348}
349
/// AnyMoE support is delegated entirely to the text backbone; the vision tower
/// and projector are untouched by MoE conversion.
impl AnyMoeBaseModelMixin for Mistral3Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        self.text_model.get_mlps()
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        self.text_model.get_mlps_mut()
    }
    // Forwards MoE layer creation to the text model with all arguments intact.
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        self.text_model.create_anymoe_layers(
            additional_vbs,
            config,
            (prefix, mlp),
            layers,
            expert_type,
            gate_vb,
        )
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}