mistralrs_core/vision_models/mistral3/vision.rs

1use std::sync::Arc;
2
3use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
4use mistralrs_quant::{linear_b, QuantMethod, ShardedVarBuilder};
5
6use crate::{
7    layers::{self, GetFloatInfo, RmsNorm},
8    pipeline::NormalLoadingMetadata,
9};
10
11fn default_act() -> candle_nn::Activation {
12    candle_nn::Activation::Silu
13}
14
/// Serde default for `hidden_size`.
fn default_hidden_size() -> usize {
    const HIDDEN_SIZE: usize = 1024;
    HIDDEN_SIZE
}
18
/// Serde default for `intermediate_size` (the MLP hidden width).
fn default_intermediate_size() -> usize {
    const INTERMEDIATE_SIZE: usize = 4096;
    INTERMEDIATE_SIZE
}
22
/// Serde default for `num_channels` (input channels of the patch convolution).
fn default_num_channels() -> usize {
    const NUM_CHANNELS: usize = 3;
    NUM_CHANNELS
}
26
/// Serde default for `num_hidden_layers` (transformer depth).
fn default_num_hidden_layers() -> usize {
    const NUM_HIDDEN_LAYERS: usize = 24;
    NUM_HIDDEN_LAYERS
}
30
/// Serde default for `num_attention_heads`.
fn default_num_attention_heads() -> usize {
    const NUM_ATTENTION_HEADS: usize = 16;
    NUM_ATTENTION_HEADS
}
34
/// Configuration of the Mistral3 vision encoder, deserialized with serde.
///
/// Fields carrying `#[serde(default = "...")]` fall back to the value returned by the named
/// helper when they are absent from the input; the remaining fields are required.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Mistral3VisionConfig {
    /// Width of the patch embeddings and transformer hidden states (default 1024).
    #[serde(default = "default_hidden_size")]
    pub hidden_size: usize,
    /// Number of input channels of the patch convolution (default 3).
    #[serde(default = "default_num_channels")]
    pub num_channels: usize,
    /// Input image side length; `image_size / patch_size` sizes the per-side extent of the
    /// rotary position tables.
    pub image_size: usize,
    /// Kernel size and stride of the patch convolution.
    pub patch_size: usize,
    /// Base used to compute the rotary embedding frequencies.
    pub rope_theta: f64,
    /// Hidden width of the gated MLP (default 4096).
    #[serde(default = "default_intermediate_size")]
    pub intermediate_size: usize,
    /// Number of transformer layers (default 24).
    #[serde(default = "default_num_hidden_layers")]
    pub num_hidden_layers: usize,
    /// Optional explicit per-head dimension; when `None`,
    /// `hidden_size / num_attention_heads` is used instead.
    pub head_dim: Option<usize>,
    /// Number of attention heads (default 16).
    #[serde(default = "default_num_attention_heads")]
    pub num_attention_heads: usize,
    /// Activation applied to the MLP gate projection (default SiLU).
    #[serde(default = "default_act")]
    pub hidden_act: candle_nn::Activation,
}
54
55impl Mistral3VisionConfig {
56    fn head_dim(&self) -> usize {
57        self.head_dim
58            .unwrap_or(self.hidden_size / self.num_attention_heads)
59    }
60}
61
/// Multi-head self-attention over patch tokens, with rotary position embeddings applied to
/// queries and keys.
#[derive(Debug, Clone)]
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    /// Output projection applied after the heads are merged back together.
    o_proj: Arc<dyn QuantMethod>,
    /// Multiplier applied to the query-key scores (`head_dim^-0.5`).
    scale: f64,
    num_heads: usize,
    head_dim: usize,
}
72
73impl Attention {
74    fn new(cfg: &Mistral3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
75        let h = cfg.hidden_size;
76        let num_heads = cfg.num_attention_heads;
77        let head_dim = cfg.head_dim();
78        let q_proj = linear_b(h, h, false, &None, vb.pp("q_proj"))?;
79        let k_proj = linear_b(h, h, false, &None, vb.pp("k_proj"))?;
80        let v_proj = linear_b(h, h, false, &None, vb.pp("v_proj"))?;
81        let o_proj = linear_b(h, h, false, &None, vb.pp("o_proj"))?;
82        let scale = (head_dim as f64).powf(-0.5);
83        Ok(Self {
84            q_proj,
85            k_proj,
86            v_proj,
87            o_proj,
88            scale,
89            num_heads,
90            head_dim,
91        })
92    }
93
94    fn forward(
95        &self,
96        xs: &Tensor,
97        emb: &RotaryEmbedding,
98        subsampled_positions: Option<&Tensor>,
99        attention_mask: Option<&Tensor>,
100    ) -> Result<Tensor> {
101        let (b, patches, _) = xs.dims3()?;
102        let query_states = self.q_proj.forward_autocast(xs)?;
103        let key_states = self.k_proj.forward_autocast(xs)?;
104        let value_states = self.v_proj.forward_autocast(xs)?;
105
106        let shape = (b, patches, self.num_heads, self.head_dim);
107        let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
108        let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
109        let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
110
111        let (query_states, key_states) =
112            emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?;
113        let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
114
115        let attn_weights = match attention_mask {
116            None => attn_weights,
117            Some(mask) => attn_weights.broadcast_add(mask)?,
118        };
119
120        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
121
122        self.o_proj.forward_autocast(
123            &attn_weights
124                .matmul(&value_states)?
125                .transpose(1, 2)?
126                .reshape((b, patches, ()))?,
127        )
128    }
129}
130
/// Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
#[derive(Debug, Clone)]
struct Mlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    /// Activation applied to the gate projection output.
    act_fn: candle_nn::Activation,
}
138
139impl Mlp {
140    fn new(cfg: &Mistral3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
141        let (h, i) = (cfg.hidden_size, cfg.intermediate_size);
142        let gate_proj = linear_b(h, i, false, &None, vb.pp("gate_proj"))?;
143        let up_proj = linear_b(h, i, false, &None, vb.pp("up_proj"))?;
144        let down_proj = linear_b(i, h, false, &None, vb.pp("down_proj"))?;
145        Ok(Self {
146            gate_proj,
147            up_proj,
148            down_proj,
149            act_fn: cfg.hidden_act,
150        })
151    }
152}
153
154impl Module for Mlp {
155    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
156        self.down_proj.forward_autocast(
157            &(self.gate_proj.forward_autocast(xs)?.apply(&self.act_fn)?
158                * self.up_proj.forward_autocast(xs)?)?,
159        )
160    }
161}
162
/// Pre-norm transformer layer: RMS-normed self-attention and an RMS-normed gated MLP, each
/// wrapped in a residual connection.
#[derive(Debug, Clone)]
struct AttentionLayer {
    /// Norm applied to the input before attention.
    attention_norm: RmsNorm,
    feed_forward: Mlp,
    attention: Attention,
    /// Norm applied to the post-attention residual stream before the MLP.
    ffn_norm: RmsNorm,
}
170
171impl AttentionLayer {
172    fn new(
173        cfg: &Mistral3VisionConfig,
174        vb: ShardedVarBuilder,
175        normal_loading_metadata: &NormalLoadingMetadata,
176    ) -> Result<Self> {
177        let attention_norm = RmsNorm::new(
178            cfg.hidden_size,
179            1e-5,
180            vb.pp("attention_norm")
181                .set_device(normal_loading_metadata.real_device.clone()),
182        )?;
183        let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?;
184        let attention = Attention::new(cfg, vb.pp("attention"))?;
185        let ffn_norm = RmsNorm::new(
186            cfg.hidden_size,
187            1e-5,
188            vb.pp("ffn_norm")
189                .set_device(normal_loading_metadata.real_device.clone()),
190        )?;
191        Ok(Self {
192            attention_norm,
193            feed_forward,
194            attention,
195            ffn_norm,
196        })
197    }
198
199    fn forward(
200        &self,
201        xs: &Tensor,
202        emb: &RotaryEmbedding,
203        subsampled_positions: Option<&Tensor>,
204        attention_mask: Option<&Tensor>,
205    ) -> Result<Tensor> {
206        let residual = xs;
207        let xs = self.attention.forward(
208            &xs.apply(&self.attention_norm)?,
209            emb,
210            subsampled_positions,
211            attention_mask,
212        )?;
213        let xs = (residual + xs)?;
214        let residual = &xs;
215        let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
216        xs + residual
217    }
218}
219
/// Stack of `AttentionLayer`s applied sequentially.
#[derive(Debug, Clone)]
struct Transformer {
    layers: Vec<AttentionLayer>,
}
224
225impl Transformer {
226    fn new(
227        cfg: &Mistral3VisionConfig,
228        vb: ShardedVarBuilder,
229        normal_loading_metadata: &NormalLoadingMetadata,
230    ) -> Result<Self> {
231        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
232        let vb = vb.pp("layers");
233        for layer_idx in 0..cfg.num_hidden_layers {
234            let layer = AttentionLayer::new(cfg, vb.pp(layer_idx), normal_loading_metadata)?;
235            layers.push(layer)
236        }
237        Ok(Self { layers })
238    }
239
240    fn forward(
241        &self,
242        xs: &Tensor,
243        emb: &RotaryEmbedding,
244        subsampled_positions: Option<&Tensor>,
245        attention_mask: Option<&Tensor>,
246    ) -> Result<Tensor> {
247        let mut xs = xs.clone();
248        for layer in self.layers.iter() {
249            xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)?
250        }
251        Ok(xs)
252    }
253}
254
/// Precomputed 2-D rotary embedding tables for the maximum patch grid.
///
/// Row `h * max_patches_per_side + w` holds the angles for the patch at grid position
/// `(h, w)`; each row has `head_dim / 2` columns.
#[derive(Debug, Clone)]
struct RotaryEmbedding {
    /// Cosine table, stored in the var builder's dtype.
    cos: Tensor,
    /// Sine table, stored in the var builder's dtype.
    sin: Tensor,
}
260
261impl RotaryEmbedding {
262    fn new(cfg: &Mistral3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
263        let dtype = vb.dtype();
264        let dev = vb.device();
265        let dim = cfg.head_dim();
266        let rope_theta = cfg.rope_theta as f32;
267        let max_patches_per_side = cfg.image_size / cfg.patch_size;
268        let freqs: Vec<_> = (0..dim)
269            .step_by(2)
270            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
271            .collect();
272        let freqs_h = freqs.iter().step_by(2).copied().collect::<Vec<_>>();
273        let freqs_h = Tensor::new(freqs_h, dev)?;
274        let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::<Vec<_>>();
275        let freqs_w = Tensor::new(freqs_w, dev)?;
276        let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
277        let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
278        let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?;
279        let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?;
280        let inv_freq = Tensor::cat(
281            &[
282                freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?,
283                freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?,
284            ],
285            D::Minus1,
286        )?
287        .reshape(((), dim / 2))?;
288        let cos = inv_freq.cos()?.to_dtype(dtype)?;
289        let sin = inv_freq.sin()?.to_dtype(dtype)?;
290        Ok(Self { cos, sin })
291    }
292
293    fn apply_rotary_emb_qkv(
294        &self,
295        q: &Tensor,
296        k: &Tensor,
297        subsampled_positions: Option<&Tensor>,
298    ) -> Result<(Tensor, Tensor)> {
299        let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
300        let (cos, sin) = match subsampled_positions {
301            None => (&self.cos, &self.sin),
302            Some(pos) => (
303                &self.cos.index_select(pos, 0)?,
304                &self.sin.index_select(pos, 0)?,
305            ),
306        };
307        let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
308        let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
309        Ok((q_embed, k_embed))
310    }
311}
312
/// Mistral3 vision encoder: a patch convolution, an RMS pre-norm, and a transformer with 2-D
/// rotary position embeddings, run over the concatenated patches of all images.
#[derive(Debug, Clone)]
pub struct Mistral3VisionModel {
    patch_conv: candle_nn::Conv2d,
    ln_pre: RmsNorm,
    transformer: Transformer,
    patch_positional_embedding: RotaryEmbedding,
    /// Patches per side of the largest grid (`image_size / patch_size`); used as the row
    /// stride when flattening 2-D patch positions.
    max_image_width: u32,
    patch_size: usize,
    /// Dtype of the var builder the model was loaded with.
    dtype: DType,
}
323
324impl Mistral3VisionModel {
325    pub fn new(
326        cfg: &Mistral3VisionConfig,
327        vb: ShardedVarBuilder,
328        normal_loading_metadata: &NormalLoadingMetadata,
329    ) -> Result<Self> {
330        let conv2d_cfg = candle_nn::Conv2dConfig {
331            stride: cfg.patch_size,
332            ..Default::default()
333        };
334        let patch_conv = layers::conv2d_no_bias(
335            cfg.num_channels,
336            cfg.hidden_size,
337            cfg.patch_size,
338            conv2d_cfg,
339            vb.pp("patch_conv")
340                .set_device(normal_loading_metadata.real_device.clone()),
341        )?;
342        let ln_pre = RmsNorm::new(
343            cfg.hidden_size,
344            1e-5,
345            vb.pp("ln_pre")
346                .set_device(normal_loading_metadata.real_device.clone()),
347        )?;
348        let transformer = Transformer::new(cfg, vb.pp("transformer"), normal_loading_metadata)?;
349        let patch_positional_embedding = RotaryEmbedding::new(
350            cfg,
351            vb.pp("patch_positional_embedding")
352                .set_device(normal_loading_metadata.real_device.clone()),
353        )?;
354        let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
355        Ok(Self {
356            patch_conv,
357            ln_pre,
358            transformer,
359            patch_positional_embedding,
360            max_image_width,
361            patch_size: cfg.patch_size,
362            dtype: vb.dtype(),
363        })
364    }
365
366    fn position_ids_in_meshgrid(
367        &self,
368        patch_embeds_list: &Vec<Tensor>,
369        device: &Device,
370    ) -> Result<Tensor> {
371        let mut positions = Vec::new();
372        for patch in patch_embeds_list {
373            let (height, width) = (patch.dim(D::Minus2)?, patch.dim(D::Minus1)?);
374            let idx = Tensor::arange(0, height as u32, device)?;
375            let idy = Tensor::arange(0, width as u32, device)?;
376            let mesh = Tensor::meshgrid(&[idx, idy], false)?;
377            let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?;
378            positions.push(ids);
379        }
380        Tensor::cat(&positions, 0)
381    }
382
383    fn generate_block_attention_mask(
384        &self,
385        patch_embeds_list: Vec<usize>,
386        patch_embeds: &Tensor,
387    ) -> Result<Tensor> {
388        let seq_len = patch_embeds.dim(1)?;
389        let mut causal_mask = (Tensor::ones(
390            (seq_len, seq_len),
391            patch_embeds.dtype(),
392            patch_embeds.device(),
393        )? * patch_embeds.dtype().finfo()?.min)?;
394
395        let block_end_idx: Vec<usize> = patch_embeds_list.iter().fold(Vec::new(), |mut acc, &x| {
396            let new_sum = x + acc.last().copied().unwrap_or(0);
397            acc.push(new_sum);
398            acc
399        });
400        let block_start_idx: Vec<usize> = {
401            let mut extended = vec![0];
402            extended.extend_from_slice(&patch_embeds_list[..patch_embeds_list.len() - 1]);
403            extended.into_iter().fold(Vec::new(), |mut acc, x| {
404                let new_sum = x + acc.last().copied().unwrap_or(0);
405                acc.push(new_sum);
406                acc
407            })
408        };
409        for (start, end) in block_start_idx.into_iter().zip(block_end_idx) {
410            causal_mask = causal_mask.slice_assign(
411                &[&(start..end), &(start..end)],
412                &Tensor::zeros(
413                    (end - start, end - start),
414                    causal_mask.dtype(),
415                    causal_mask.device(),
416                )?,
417            )?;
418        }
419
420        causal_mask
421            .reshape((1, 1, causal_mask.dim(0)?, causal_mask.dim(1)?))?
422            .repeat((patch_embeds.dim(0)?, 1, 1, 1))
423    }
424
425    pub fn forward(&self, xs: &Tensor, image_sizes: Vec<(u32, u32)>) -> Result<Tensor> {
426        let patch_embeds = xs.apply(&self.patch_conv)?;
427        let patch_embeds_list = image_sizes
428            .iter()
429            .enumerate()
430            .map(|(i, &size)| {
431                patch_embeds
432                    .i(i)?
433                    .narrow(D::Minus2, 0, size.0 as usize / self.patch_size)?
434                    .narrow(D::Minus1, 0, size.1 as usize / self.patch_size)
435            })
436            .collect::<Result<Vec<Tensor>>>()?;
437        let patch_embeds = Tensor::cat(
438            &patch_embeds_list
439                .iter()
440                .map(|p| p.flatten_from(1)?.t())
441                .collect::<Result<Vec<Tensor>>>()?,
442            0,
443        )?
444        .unsqueeze(0)?;
445        let patch_embeds = patch_embeds.apply(&self.ln_pre)?;
446
447        let subsampled_positions =
448            Some(self.position_ids_in_meshgrid(&patch_embeds_list, patch_embeds.device())?);
449
450        let attention_mask = self.generate_block_attention_mask(
451            patch_embeds_list
452                .iter()
453                .map(|p| Ok(p.dim(D::Minus2)? * p.dim(D::Minus1)?))
454                .collect::<Result<Vec<usize>>>()?,
455            &patch_embeds,
456        )?;
457
458        self.transformer.forward(
459            &patch_embeds,
460            &self.patch_positional_embedding,
461            subsampled_positions.as_ref(),
462            Some(&attention_mask),
463        )
464    }
465
466    pub fn dtype(&self) -> DType {
467        self.dtype
468    }
469
470    pub fn get_layers(&mut self) -> Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)> {
471        let mut tensors = Vec::new();
472        for layer in &mut self.transformer.layers {
473            tensors.push((&mut layer.attention.q_proj, None));
474            tensors.push((&mut layer.attention.k_proj, None));
475            tensors.push((&mut layer.attention.v_proj, None));
476            tensors.push((&mut layer.attention.o_proj, None));
477
478            tensors.push((&mut layer.feed_forward.gate_proj, None));
479            tensors.push((&mut layer.feed_forward.up_proj, None));
480            tensors.push((&mut layer.feed_forward.down_proj, None));
481        }
482        tensors
483    }
484}