// mistralrs_core/vision_models/qwen2vl/vision.rs

use std::sync::Arc;

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{LayerNorm, Linear, Module};
use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};

use crate::{
    layers::{self, layer_norm, Activation, Conv3dConfig, Conv3dNoBias, MatMul},
    ops::RepeatInterleaveOp,
};

use super::config::VisionConfig;
13
/// 3D convolutional patch embedding for Qwen2-VL vision input.
struct PatchEmbed {
    proj: Conv3dNoBias,
    // Cached config values, used by `forward` to reshape the flattened
    // patch buffer before the convolution.
    in_channels: usize,
    patch_size: usize,
    temporal_patch_size: usize,
    embed_dim: usize,
}
21
22// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L272
23impl PatchEmbed {
24    fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
25        if cfg.temporal_patch_size != 2 {
26            candle_core::bail!("Only support temporal patch size of 2");
27        }
28        Ok(Self {
29            proj: Conv3dNoBias::new(
30                cfg.in_channels,
31                cfg.embed_dim,
32                [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size],
33                Conv3dConfig {
34                    stride: cfg.patch_size,
35                    ..Default::default()
36                },
37                vb.pp("proj"),
38            )?,
39            in_channels: cfg.in_channels,
40            patch_size: cfg.patch_size,
41            temporal_patch_size: cfg.temporal_patch_size,
42            embed_dim: cfg.embed_dim,
43        })
44    }
45
46    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
47        let xs = xs.reshape((
48            (),
49            self.in_channels,
50            self.temporal_patch_size,
51            self.patch_size,
52            self.patch_size,
53        ))?;
54        xs.apply(&self.proj)?.reshape(((), self.embed_dim))
55    }
56}
57
58// https://github.com/huggingface/transformers/blob/a769ed45e17c44fd17b85c025863c4e4f2f73634/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L314
/// Two-layer feed-forward network used inside each vision block.
struct VisionMlp {
    fc1: Arc<dyn QuantMethod>, // dim -> hidden_dim
    fc2: Arc<dyn QuantMethod>, // hidden_dim -> dim
    act: Activation,
}
64
65impl VisionMlp {
66    fn new(
67        dim: usize,
68        hidden_dim: usize,
69        act: Activation,
70        vb: ShardedVarBuilder,
71        comm: &Arc<mistralrs_quant::Comm>,
72    ) -> Result<Self> {
73        Ok(Self {
74            fc1: ColumnParallelLayer::new(dim, hidden_dim, &None, true, comm, vb.pp("fc1"))?,
75            fc2: ColumnParallelLayer::new(hidden_dim, dim, &None, true, comm, vb.pp("fc2"))?,
76            act,
77        })
78    }
79
80    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
81        let fc1 = self.act.forward(&self.fc1.forward(&xs.unsqueeze(0)?)?)?;
82        self.fc2.forward(&fc1)?.squeeze(0)
83    }
84}
85
86fn rotate_half(xs: &Tensor) -> Result<Tensor> {
87    let last_dim = xs.dim(D::Minus1)?;
88    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
89    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
90    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
91}
92
93fn apply_rotary_pos_emb_vision(xs: &Tensor, freqs: &Tensor) -> Result<Tensor> {
94    let cos = freqs.cos()?;
95    let sin = freqs.sin()?;
96
97    xs.broadcast_mul(&cos)? + rotate_half(xs)?.broadcast_mul(&sin)
98}
99
100// https://github.com/huggingface/transformers/blob/a769ed45e17c44fd17b85c025863c4e4f2f73634/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L325
/// Multi-head self-attention over the flattened patch sequence.
struct VisionAttention {
    qkv: Arc<dyn QuantMethod>,  // fused Q/K/V projection: dim -> 3 * dim
    proj: Arc<dyn QuantMethod>, // output projection: dim -> dim
    num_heads: usize,
    head_dim: usize,
}
107
impl VisionAttention {
    /// Build the fused QKV and output projections.
    /// `head_dim` assumes `dim` is divisible by `num_heads`.
    fn new(dim: usize, num_heads: usize, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            qkv: mistralrs_quant::linear(dim, dim * 3, &None, vb.pp("qkv"))?,
            proj: mistralrs_quant::linear(dim, dim, &None, vb.pp("proj"))?,
            num_heads,
            head_dim: dim / num_heads,
        })
    }
    /// Self-attention over a flattened patch sequence.
    ///
    /// * `xs` - `(seq_len, dim)` hidden states.
    /// * `attention_mask` - optional additive mask broadcast onto the attention
    ///   logits (used as a block-diagonal mask across images/frames by the caller).
    /// * `rotary_pos_emb` - rotary frequencies applied to Q and K.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        rotary_pos_emb: &Tensor,
    ) -> Result<Tensor> {
        let seq_len = xs.dim(0)?;
        // Project once, then split into (seq, heads, head_dim) Q, K, V:
        // reshape -> (seq, 3, heads, hd), permute -> (3, seq, heads, hd),
        // chunk along dim 0 and drop the singleton.
        let (mut q, mut k, mut v) = {
            let qkv = self
                .qkv
                .forward(&xs.unsqueeze(0)?)?
                .reshape((seq_len, 3, self.num_heads, ()))?
                .permute((1, 0, 2, 3))?
                .chunk(3, 0)?;
            (qkv[0].squeeze(0)?, qkv[1].squeeze(0)?, qkv[2].squeeze(0)?)
        };

        // Rotary embedding is computed with a temporary batch dim; restore
        // the original dtype afterwards.
        q = apply_rotary_pos_emb_vision(&q.unsqueeze(0)?, rotary_pos_emb)?
            .squeeze(0)?
            .to_dtype(q.dtype())?;
        k = apply_rotary_pos_emb_vision(&k.unsqueeze(0)?, rotary_pos_emb)?
            .squeeze(0)?
            .to_dtype(q.dtype())?;

        // (seq, heads, hd) -> (heads, seq, hd) for per-head batched matmul.
        q = q.transpose(0, 1)?.contiguous()?;
        k = k.transpose(0, 1)?.contiguous()?;
        v = v.transpose(0, 1)?.contiguous()?;

        let att = {
            // Scaled dot-product attention with optional additive mask.
            let mut att =
                (MatMul.matmul(&q, &k.transpose(1, 2)?)? / (self.head_dim as f64).sqrt())?;
            att = match attention_mask {
                Some(m) => att.broadcast_add(m)?,
                None => att,
            };
            att = candle_nn::ops::softmax_last_dim(&att)?;
            // (heads, seq, hd) -> (seq, heads * hd), back to the input dtype.
            MatMul
                .matmul(&att, &v)?
                .transpose(0, 1)?
                .reshape((seq_len, ()))?
                .to_dtype(xs.dtype())?
        };

        self.proj.forward(&att.unsqueeze(0)?)?.squeeze(0)
    }
}
163
164// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L418
/// One pre-norm transformer block: LayerNorm -> attention -> residual,
/// then LayerNorm -> MLP -> residual.
struct VisionBlock {
    norm1: LayerNorm,
    norm2: LayerNorm,
    mlp: VisionMlp,
    attn: VisionAttention,
}
171
172impl VisionBlock {
173    fn new(
174        cfg: &VisionConfig,
175        vb: ShardedVarBuilder,
176        comm: &Arc<mistralrs_quant::Comm>,
177    ) -> Result<Self> {
178        let norm1 = layer_norm(cfg.embed_dim, 1e-6, vb.pp("norm1"))?;
179        let norm2 = layer_norm(cfg.embed_dim, 1e-6, vb.pp("norm2"))?;
180
181        let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
182        let mlp = VisionMlp::new(
183            cfg.embed_dim,
184            mlp_hidden_dim,
185            cfg.hidden_act,
186            vb.pp("mlp"),
187            comm,
188        )?;
189        let attn = VisionAttention::new(cfg.embed_dim, cfg.num_heads, vb.pp("attn"))?;
190
191        Ok(Self {
192            norm1,
193            norm2,
194            mlp,
195            attn,
196        })
197    }
198
199    fn forward(
200        &self,
201        xs: &Tensor,
202        attention_mask: Option<&Tensor>,
203        rotary_pos_emb: &Tensor,
204    ) -> Result<Tensor> {
205        let xs = (xs
206            + self
207                .attn
208                .forward(&self.norm1.forward(xs)?, attention_mask, rotary_pos_emb)?)?;
209        &xs + self.mlp.forward(&self.norm2.forward(&xs)?)?
210    }
211}
212
/// Final projection that merges groups of neighboring patch embeddings
/// into single tokens of the language model's hidden size.
struct PatchMerger {
    ln_q: LayerNorm,
    mlp0: Linear, // "mlp.0": hidden_size -> hidden_size
    mlp2: Linear, // "mlp.2": hidden_size -> output dim
    // context_dim * spatial_merge_size^2 — width of one merged group.
    hidden_size: usize,
}
219
220impl PatchMerger {
221    pub fn new(
222        dim: usize,
223        context_dim: usize,
224        spatial_merge_size: usize,
225        vb: ShardedVarBuilder,
226    ) -> Result<Self> {
227        let hidden_size = context_dim * spatial_merge_size.pow(2);
228        let mlp0 = layers::linear(hidden_size, hidden_size, vb.pp("mlp.0"))?;
229        let mlp2 = layers::linear(hidden_size, dim, vb.pp("mlp.2"))?;
230        Ok(Self {
231            ln_q: layer_norm(context_dim, 1e-6, vb.pp("ln_q"))?,
232            mlp0,
233            mlp2,
234            hidden_size,
235        })
236    }
237
238    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
239        xs.unsqueeze(0)?
240            .apply(&self.ln_q)?
241            .reshape(((), self.hidden_size))?
242            .apply(&self.mlp0)?
243            .gelu()?
244            .apply(&self.mlp2)?
245            .squeeze(0)
246    }
247}
248
/// Rotary position embedding helper for the vision tower.
struct VisionRotaryEmbedding {
    // (1, ceil(dim/2)) inverse frequencies; see `make_embeds`.
    inv_freq: Tensor,
}
252
253impl VisionRotaryEmbedding {
254    const THETA: f32 = 10000.;
255
256    fn new(dim: usize, device: &Device) -> Result<Self> {
257        let inv_freq = (0..dim)
258            .step_by(2)
259            .map(|i| 1f32 / Self::THETA.powf(i as f32 / dim as f32))
260            .collect::<Vec<_>>();
261        let inv_freq_len = inv_freq.len();
262        Ok(Self {
263            inv_freq: Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?,
264        })
265    }
266
267    fn make_embeds(&self, seqlen: usize) -> Result<Tensor> {
268        let seq =
269            Tensor::arange(0f32, seqlen as f32, self.inv_freq.device())?.unsqueeze(D::Minus1)?;
270        seq.broadcast_matmul(&self.inv_freq)
271    }
272}
273
/// Qwen2-VL vision transformer tower: patch embedding, a stack of
/// transformer blocks with rotary position embeddings, and a patch merger.
pub struct Qwen2VLVisionModel {
    blocks: Vec<VisionBlock>,
    patch_merger: PatchMerger,
    patch_embed: PatchEmbed,
    rotary_pos_emb: VisionRotaryEmbedding,
    spatial_merge_size: usize,
}
281
impl Qwen2VLVisionModel {
    /// Build the full vision tower from `cfg`: `depth` transformer blocks,
    /// the patch embedding, the rotary embedding, and the patch merger.
    pub fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut blocks = Vec::new();
        for i in 0..cfg.depth {
            blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")), comm)?);
        }

        let patch_merger = PatchMerger::new(
            cfg.hidden_size,
            cfg.embed_dim,
            cfg.spatial_merge_size,
            vb.pp("merger"),
        )?;

        let patch_embed = PatchEmbed::new(cfg, vb.pp("patch_embed"))?;

        // The rotary table covers half the head dim; `forward` tiles it to
        // the full head dim before use.
        let head_dim = cfg.embed_dim / cfg.num_heads;
        let rotary_pos_emb = VisionRotaryEmbedding::new(head_dim / 2, vb.device())?;

        Ok(Self {
            blocks,
            patch_embed,
            patch_merger,
            rotary_pos_emb,
            spatial_merge_size: cfg.spatial_merge_size,
        })
    }

    /// Per-patch rotary embeddings from the (t, h, w) grid of each image/video.
    ///
    /// `grid_thw` is an (n, 3) tensor of u32 `[t, h, w]` triples. Assumes h
    /// and w are divisible by `spatial_merge_size` (the reshapes below fail
    /// otherwise).
    fn rot_pos_emb(&self, grid_thw: &Tensor, device: &Device) -> Result<Tensor> {
        let mut pos_ids = Vec::new();
        for i_thw in grid_thw.to_vec2::<u32>()? {
            let (t, h, w) = (i_thw[0], i_thw[1], i_thw[2]);
            // Row indices, reordered so patches within the same
            // spatial-merge window end up adjacent after flattening.
            let mut hpos_ids = Tensor::arange(0, h, device)?
                .unsqueeze(1)?
                .repeat((1, w as usize))?;
            hpos_ids = hpos_ids.reshape((
                h as usize / self.spatial_merge_size,
                self.spatial_merge_size,
                w as usize / self.spatial_merge_size,
                self.spatial_merge_size,
            ))?;
            hpos_ids = hpos_ids.permute((0, 2, 1, 3))?;
            hpos_ids = hpos_ids.flatten_all()?;

            // Column indices with the same merge-window ordering.
            let mut wpos_ids = Tensor::arange(0, w, device)?
                .unsqueeze(0)?
                .repeat((h as usize, 1))?;
            wpos_ids = wpos_ids.reshape((
                h as usize / self.spatial_merge_size,
                self.spatial_merge_size,
                w as usize / self.spatial_merge_size,
                self.spatial_merge_size,
            ))?;
            wpos_ids = wpos_ids.permute((0, 2, 1, 3))?;
            wpos_ids = wpos_ids.flatten_all()?;

            // (h*w, 2) pairs of (row, col), repeated for each temporal frame.
            pos_ids.push(Tensor::stack(&[hpos_ids, wpos_ids], D::Minus1)?.repeat((t as usize, 1))?);
        }
        let pos_ids = Tensor::cat(&pos_ids, 0)?;
        // Largest spatial extent across all grids; the frequency table is
        // built once up to this size and gathered per position below.
        let max_grid_size = grid_thw.i((.., 1..))?.max(0)?.max(0)?.to_scalar::<u32>()?;
        let rotary_pos_emb_full = self.rotary_pos_emb.make_embeds(max_grid_size as usize)?;

        assert_eq!(pos_ids.rank(), 2);
        // Gather the (row, col) frequencies and concatenate them on the last dim.
        rotary_pos_emb_full
            .index_select(&pos_ids.flatten_all()?, 0)?
            .reshape((pos_ids.dim(0)?, pos_ids.dim(1)?, ()))?
            .flatten_from(1)
    }

    /// Run the vision tower over flattened patches `xs` with grid metadata
    /// `grid_thw` ((n, 3) of `[t, h, w]`), returning merged patch embeddings.
    pub fn forward(&self, xs: &Tensor, grid_thw: &Tensor) -> Result<Tensor> {
        // Match the input dtype to the model weights before embedding.
        let mut xs = self
            .patch_embed
            .forward(&xs.to_dtype(self.patch_merger.mlp0.weight().dtype())?)?;
        let rotary_pos_emb = self.rot_pos_emb(grid_thw, xs.device())?;
        // Tile the half-dim frequencies to the full head dim (equivalent to
        // HF's cat((freqs, freqs), dim=-1)) and add batch/head broadcast dims.
        let rotary_pos_emb = rotary_pos_emb
            .unsqueeze(1)?
            .repeat((1, 1, 2))?
            .unsqueeze(0)?
            .to_dtype(xs.dtype())?;

        // Cumulative sequence lengths: each temporal frame contributes h*w
        // patches and attends only within itself (block-diagonal attention).
        let grid_thw = grid_thw.to_device(&Device::Cpu)?;
        let cu_seqlens = (grid_thw.i((.., 1))? * grid_thw.i((.., 2))?)?
            .repeat_interleave_flat(grid_thw.i((.., 0))?.to_vec1::<u32>()?)?
            .to_dtype(DType::F32)?
            .cumsum(0)?
            .to_dtype(DType::U32)?
            .pad_with_zeros(0, 1, 0)? // prepend the leading 0 boundary
            .to_vec1::<u32>()?;

        let seq_len = xs.dim(0)?;
        let attention_mask = match &cu_seqlens[..] {
            // Single frame spanning the whole sequence: no mask needed.
            &[0, len] if len == seq_len as u32 => None,
            cu_seqlens => {
                // Start fully masked, then zero each diagonal block so
                // attention stays within a frame.
                let mut attention_mask =
                    Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
                        .to_dtype(xs.dtype())?;
                for i in 1..cu_seqlens.len() {
                    let a = cu_seqlens[i - 1] as usize;
                    let b = cu_seqlens[i] as usize;
                    attention_mask = attention_mask.slice_assign(
                        &[&.., &(a..b), &(a..b)],
                        &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
                    )?;
                }
                Some(attention_mask)
            }
        };

        for blk in &self.blocks {
            xs = blk.forward(&xs, attention_mask.as_ref(), &rotary_pos_emb)?;
        }

        self.patch_merger.forward(&xs)
    }
}