mistralrs_core/vision_models/mistral3/mod.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5use crate::{
6    amoe::{AnyMoeBaseModelMixin, MlpLayer},
7    device_map::DeviceMapper,
8    layers::{self, Activation, RmsNorm},
9    models,
10    ops::{NonZeroOp, SplitOp},
11    paged_attention::{AttentionImplementation, ModelConfigMetadata},
12    pipeline::{
13        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
14        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
15    },
16    utils::unvarbuilder::UnVarBuilder,
17    AnyMoeConfig, AnyMoeExpertType,
18};
19use candle_core::{DType, Device, Result, Tensor, D};
20use candle_nn::{Linear, Module};
21pub use config::Mistral3Config;
22pub use inputs_processor::Mistral3Processor;
23use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
24use models::mistral::Model as Mistral;
25use vision::Mistral3VisionModel;
26
27mod config;
28mod inputs_processor;
29mod vision;
30
/// Spatially merges groups of vision patches into single tokens, then projects
/// the concatenated group features back to the vision hidden size.
struct Mistral3PatchMerger {
    // Linear projection from `hidden_size * spatial_merge_size^2` down to `hidden_size`.
    merging_layer: Linear,
    // Side length of the square group of patches merged into one token.
    spatial_merge_size: usize,
    // Pixel size of one vision patch; used to convert image sizes to patch-grid sizes.
    patch_size: usize,
}
36
impl Mistral3PatchMerger {
    /// Loads the merger from `vb` (expects a `merging_layer` weight).
    ///
    /// The projection's input dimension is `hidden_size * spatial_merge_size^2`
    /// because each merged token is the concatenation of a
    /// `spatial_merge_size x spatial_merge_size` group of patch features.
    fn new(cfg: &Mistral3Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            merging_layer: layers::linear_no_bias(
                cfg.vision_config.hidden_size * cfg.spatial_merge_size.pow(2),
                cfg.vision_config.hidden_size,
                vb.pp("merging_layer"),
            )?,
            spatial_merge_size: cfg.spatial_merge_size,
            patch_size: cfg.vision_config.patch_size,
        })
    }

    /// Merges patch features per image and projects them.
    ///
    /// `image_features` is the flat concatenation of all images' patch tokens
    /// along dim 0 (split per image via each image's `h * w` patch count);
    /// `image_sizes` holds the pixel `(height, width)` of each image.
    /// Returns the merged tokens of all images concatenated along dim 0.
    fn forward(&self, image_features: &Tensor, image_sizes: Vec<(u32, u32)>) -> Result<Tensor> {
        // Convert pixel sizes to patch-grid sizes (integer division by patch_size).
        let image_sizes = image_sizes
            .iter()
            .map(|&(h, w)| (h as usize / self.patch_size, w as usize / self.patch_size))
            .collect::<Vec<_>>();

        let tokens_per_image = image_sizes.iter().map(|&(h, w)| h * w).collect::<Vec<_>>();
        let d = image_features.dim(D::Minus1)?;

        let mut permuted_tensor = Vec::new();

        // Process each image's tokens separately since the grids differ in shape.
        for (image_index, image_tokens) in image_features
            .split(&tokens_per_image, 0)?
            .iter()
            .enumerate()
        {
            let (h, w) = image_sizes[image_index];
            // (h*w, d) -> (1, d, h, w): channels-first grid for the unfold below.
            let image_grid = image_tokens
                .reshape((h, w, d))?
                .permute((2, 0, 1))?
                .unsqueeze(0)?;
            // Equiv of:
            // torch.nn.functional.unfold(image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size)
            let grid = {
                // The first unfold extracts sliding windows along the height (dim=2),
                // and the second unfolds the width (dim=3).
                let patches = image_grid
                    .unfold(2, self.spatial_merge_size, self.spatial_merge_size)?
                    .unfold(3, self.spatial_merge_size, self.spatial_merge_size)?;
                // patches now has shape: (N, C, n_H, n_W, K, K) where n_H = (H - K) // K + 1 and n_W = (W - K) // K + 1

                // Reorder to (N, C, K, K, n_H, n_W) so that flattening yields the
                // channel-major block layout torch's `unfold` produces: (N, C*K*K, L).
                let patches = patches.permute((0, 1, 4, 5, 2, 3))?;
                patches.contiguous()?.reshape((
                    1,
                    d * self.spatial_merge_size * self.spatial_merge_size,
                    (),
                ))?
            };
            // Drop the batch dim and transpose to (L, d*K^2): one row per merged token.
            let grid = grid
                .reshape((d * self.spatial_merge_size.pow(2), ()))?
                .t()?;
            permuted_tensor.push(grid);
        }

        let image_features = Tensor::cat(&permuted_tensor, 0)?;

        self.merging_layer.forward(&image_features)
    }
}
99
/// Projects vision-tower features into the text model's embedding space:
/// RMS-norm -> patch merge -> linear -> activation -> linear.
struct Mistral3MultiModalProjector {
    // Normalizes the raw vision features before merging.
    norm: RmsNorm,
    // First projection: vision hidden size -> text hidden size.
    linear_1: Linear,
    // Second projection: text hidden size -> text hidden size.
    linear_2: Linear,
    // Activation applied between the two linear layers.
    act: Activation,
    // Spatially merges patch groups before projection.
    patch_merger: Mistral3PatchMerger,
}
107
108impl Mistral3MultiModalProjector {
109    fn new(cfg: &Mistral3Config, vb: ShardedVarBuilder) -> Result<Self> {
110        let norm = RmsNorm::new(
111            cfg.vision_config.hidden_size,
112            cfg.text_config.rms_norm_eps,
113            vb.pp("norm"),
114        )?;
115        // let num_feature_layers = match &cfg.vision_feature_layer {
116        //     Either::Left(_) => 1,
117        //     Either::Right(r) => r.len(),
118        // };
119        let num_feature_layers = 1;
120        let linear_1 = layers::linear_b(
121            cfg.vision_config.hidden_size * num_feature_layers,
122            cfg.text_config.hidden_size,
123            cfg.multimodal_projector_bias,
124            vb.pp("linear_1"),
125        )?;
126        let linear_2 = layers::linear_b(
127            cfg.text_config.hidden_size,
128            cfg.text_config.hidden_size,
129            cfg.multimodal_projector_bias,
130            vb.pp("linear_2"),
131        )?;
132        let patch_merger = Mistral3PatchMerger::new(cfg, vb.pp("patch_merger"))?;
133        Ok(Self {
134            norm,
135            linear_1,
136            linear_2,
137            act: cfg.projector_hidden_act,
138            patch_merger,
139        })
140    }
141
142    fn forward(&self, image_features: &Tensor, image_sizes: Vec<(u32, u32)>) -> Result<Tensor> {
143        let mut hidden_states = self.norm.forward(image_features)?;
144        hidden_states = self.patch_merger.forward(&hidden_states, image_sizes)?;
145        hidden_states = self.linear_1.forward(&hidden_states)?.apply(&self.act)?;
146        self.linear_2.forward(&hidden_states)
147    }
148
149    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
150        let uvb = UnVarBuilder::new();
151
152        uvb.pp("norm").add(&self.norm);
153        uvb.pp("linear_1").add(&self.linear_1);
154        uvb.pp("linear_2").add(&self.linear_2);
155        uvb.pp("patch_merger")
156            .pp("merging_layer")
157            .add(&self.patch_merger.merging_layer);
158
159        uvb.to_safetensors()
160    }
161}
162
/// Mistral 3 vision-language model: a Mistral text decoder combined with a
/// vision tower and a multimodal projector that injects image features into
/// the text embedding stream.
pub struct Mistral3Model {
    // Mistral decoder handling the (multimodal) embedding sequence.
    text_model: Mistral,
    // Vision tower producing per-patch image features.
    vision_model: Mistral3VisionModel,
    // Projects vision features into the text embedding space.
    mmproj: Mistral3MultiModalProjector,
    // Retained config (used e.g. for `image_token_index` at forward time).
    cfg: Mistral3Config,
}
169
170impl Mistral3Model {
171    pub fn new(
172        cfg: &Mistral3Config,
173        vb: ShardedVarBuilder,
174        is_gptx: bool,
175        normal_loading_metadata: NormalLoadingMetadata,
176        attention_mechanism: AttentionImplementation,
177    ) -> Result<Self> {
178        let vision_model = Mistral3VisionModel::new(
179            &cfg.vision_config,
180            vb.pp("vision_tower"),
181            &normal_loading_metadata,
182        )?;
183        let mmproj = Mistral3MultiModalProjector::new(
184            cfg,
185            vb.pp("multi_modal_projector")
186                .set_device(normal_loading_metadata.real_device.clone()),
187        )?;
188        let text_model = Mistral::new(
189            &cfg.text_config,
190            vb.pp("language_model"),
191            is_gptx,
192            normal_loading_metadata,
193            attention_mechanism,
194        )?;
195
196        // For get_image_features, assuming this for best efficiency.
197        assert_eq!(cfg.vision_feature_layer, -1);
198
199        Ok(Self {
200            vision_model,
201            text_model,
202            mmproj,
203            cfg: cfg.clone(),
204        })
205    }
206
207    fn get_image_features(
208        &self,
209        image_features: &Tensor,
210        image_sizes: Vec<(u32, u32)>,
211    ) -> Result<Tensor> {
212        let image_outputs = self
213            .vision_model
214            .forward(image_features, image_sizes.clone())?;
215        let selected_image_feature = image_outputs;
216        self.mmproj
217            .forward(&selected_image_feature.squeeze(0)?, image_sizes)
218    }
219
220    #[allow(clippy::too_many_arguments)]
221    pub fn forward(
222        &self,
223        input_ids: &Tensor,
224        pixel_values: Option<Tensor>,
225        seqlen_offsets: &[usize],
226        context_lens: Vec<(usize, usize)>,
227        image_sizes: Option<Vec<(u32, u32)>>,
228        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
229        flash_params: &FlashParams,
230    ) -> Result<Tensor> {
231        let mut input_embeds = self.text_model.get_input_embeddings(input_ids)?;
232
233        if let Some(pixel_values) = pixel_values {
234            let image_sizes = image_sizes.unwrap();
235            let image_features = self.get_image_features(
236                &pixel_values.to_dtype(self.vision_model.dtype())?,
237                image_sizes,
238            )?;
239
240            let special_image_mask = input_ids
241                .eq(self.cfg.image_token_index as f64)?
242                .unsqueeze(D::Minus1)?
243                .broadcast_as(input_embeds.shape().clone())?
244                .to_dtype(DType::U32)?;
245
246            let mask_flat = special_image_mask.flatten_all()?;
247            let mut x_flat = input_embeds.flatten_all()?;
248            let src_flat = image_features.flatten_all()?;
249
250            let indices = mask_flat.nonzero()?.squeeze(1)?;
251            let current_vals = x_flat.gather(&indices, 0)?;
252            let diff = (src_flat - current_vals)?;
253            x_flat = x_flat.scatter_add(&indices, &diff, 0)?;
254
255            input_embeds = x_flat.reshape(input_embeds.shape())?;
256        }
257
258        self.text_model.forward_embeds(
259            input_ids,
260            input_embeds,
261            seqlen_offsets,
262            context_lens,
263            metadata,
264            flash_params,
265        )
266    }
267}
268
269impl IsqModel for Mistral3Model {
270    fn get_layers(
271        &mut self,
272    ) -> (
273        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
274        &dyn DeviceMapper,
275    ) {
276        let (mut tensors, mapper) = self.text_model.get_layers();
277        tensors.extend(self.vision_model.get_layers());
278        (tensors, mapper)
279    }
280
281    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
282        let uvb = UnVarBuilder::new();
283        uvb.pp("multi_modal_projector")
284            .extend(self.mmproj.residual_tensors());
285        uvb.pp("language_model")
286            .extend(self.text_model.residual_tensors());
287
288        uvb.to_safetensors()
289    }
290
291    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
292        self.text_model.imatrix_names()
293    }
294}
295
/// Model-specific arguments passed through the `VisionModel` trait's
/// type-erased `model_specific_args` channel.
#[derive(Default)]
pub struct Mistral3SpecificArgs {
    // Pixel (height, width) of each image in the batch; required whenever
    // `pixel_values` is supplied.
    pub image_sizes: Option<Vec<(u32, u32)>>,
}
300
301impl VisionModel for Mistral3Model {
302    fn forward(
303        &self,
304        input_ids: &Tensor,
305        pixel_values: Option<Tensor>,
306        seqlen_offsets: &[usize],
307        context_lens: Vec<(usize, usize)>,
308        _position_ids: Vec<usize>,
309        model_specific_args: Box<dyn std::any::Any>,
310        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
311        flash_params: &FlashParams,
312    ) -> candle_core::Result<Tensor> {
313        let Mistral3SpecificArgs { image_sizes } = *model_specific_args
314            .downcast()
315            .expect("Cannot downcast into `Mistral3SpecificArgs`");
316        self.forward(
317            input_ids,
318            pixel_values,
319            seqlen_offsets,
320            context_lens,
321            image_sizes,
322            metadata,
323            flash_params,
324        )
325    }
326    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
327        Box::new(Mistral3SpecificArgs::default())
328    }
329    fn cache(&self) -> &EitherCache {
330        self.text_model.cache()
331    }
332    fn cache_mut(&mut self) -> &mut EitherCache {
333        self.text_model.cache_mut()
334    }
335    fn device(&self) -> &Device {
336        self.text_model.device()
337    }
338    fn max_seq_len(&self) -> usize {
339        self.text_model.max_seq_len()
340    }
341    fn config(&self) -> &ModelConfigMetadata {
342        self.text_model.config()
343    }
344    fn has_conv2d(&self) -> bool {
345        true
346    }
347}
348
349impl AnyMoeBaseModelMixin for Mistral3Model {
350    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
351        self.text_model.get_mlps()
352    }
353    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
354        self.text_model.get_mlps_mut()
355    }
356    fn create_anymoe_layers(
357        &mut self,
358        additional_vbs: Vec<ShardedVarBuilder>,
359        config: AnyMoeConfig,
360        (prefix, mlp): (String, String),
361        layers: Vec<usize>,
362        expert_type: AnyMoeExpertType,
363        gate_vb: Option<ShardedVarBuilder>,
364    ) -> Result<()> {
365        self.text_model.create_anymoe_layers(
366            additional_vbs,
367            config,
368            (prefix, mlp),
369            layers,
370            expert_type,
371            gate_vb,
372        )
373    }
374    fn amoe_supported(&self) -> bool {
375        true
376    }
377}