mistralrs_core/vision_models/qwen2_5_vl/vision.rs

use std::sync::Arc;

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Linear, Module};
use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};

use crate::{
    layers::{self, Activation, Conv3dConfig, Conv3dNoBias, MatMul, RmsNorm},
    ops::RepeatInterleaveOp,
};

use super::config::VisionConfig;

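// Embeds pixel values with a 3D convolution over
// (temporal_patch_size, patch_size, patch_size) blocks, yielding one
// `hidden_size`-dim embedding per patch.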
struct PatchEmbed {
    proj: Conv3dNoBias,
    in_channels: usize,
    patch_size: usize,
    temporal_patch_size: usize,
    hidden_size: usize,
}

// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L272
impl PatchEmbed {
    fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        if cfg.temporal_patch_size != 2 {
            candle_core::bail!("Only support temporal patch size of 2");
        }
        Ok(Self {
            proj: Conv3dNoBias::new(
                cfg.in_chans,
                cfg.hidden_size,
                [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size],
                Conv3dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                },
                vb.pp("proj"),
            )?,
            in_channels: cfg.in_chans,
            patch_size: cfg.patch_size,
            temporal_patch_size: cfg.temporal_patch_size,
            hidden_size: cfg.hidden_size,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.reshape((
            (),
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        ))?;
        xs.apply(&self.proj)?.reshape(((), self.hidden_size))
    }
}

// https://github.com/huggingface/transformers/blob/6a1ab634b6886b6560b0502e7a305c8cd881732e/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L75
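// Gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)), built from column/row-parallel
// layers so the projections can be sharded across `comm`.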
struct VisionMlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act: Activation,
}

impl VisionMlp {
    fn new(
        dim: usize,
        hidden_dim: usize,
        act: Activation,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate_proj: ColumnParallelLayer::new(
                dim,
                hidden_dim,
                &None,
                true,
                comm,
                vb.pp("gate_proj"),
            )?,
            up_proj: ColumnParallelLayer::new(
                dim,
                hidden_dim,
                &None,
                true,
                comm,
                vb.pp("up_proj"),
            )?,
            down_proj: RowParallelLayer::new(
                hidden_dim,
                dim,
                &None,
                true,
                comm,
                vb.pp("down_proj"),
            )?,
            act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .forward(&xs.unsqueeze(0)?)?
            .apply(&self.act)?;
        let rhs = self.up_proj.forward(&xs.unsqueeze(0)?)?;
        let mut res = self.down_proj.forward(&(lhs * rhs)?)?;

        res = res.squeeze(0)?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

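// Rotary embedding helpers: `rotate_half` negates and swaps the two halves of the last
// dimension; `apply_rotary_pos_emb_vision` applies x * cos + rotate_half(x) * sin with
// the precomputed vision frequencies.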
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

fn apply_rotary_pos_emb_vision(xs: &Tensor, freqs: &Tensor) -> Result<Tensor> {
    let cos = freqs.cos()?.unsqueeze(D::Minus2)?.to_dtype(xs.dtype())?;
    let sin = freqs.sin()?.unsqueeze(D::Minus2)?.to_dtype(xs.dtype())?;

    xs.broadcast_mul(&cos)? + rotate_half(xs)?.broadcast_mul(&sin)
}

// https://github.com/huggingface/transformers/blob/a769ed45e17c44fd17b85c025863c4e4f2f73634/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L325
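// Self-attention over the flattened patch sequence: fused qkv projection, rotary
// embeddings applied to q and k, then scaled dot-product attention with an optional
// additive (block-diagonal) mask.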
struct VisionAttention {
    qkv: Arc<dyn QuantMethod>,
    proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    head_dim: usize,
}

impl VisionAttention {
    fn new(dim: usize, num_heads: usize, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            qkv: mistralrs_quant::linear(dim, dim * 3, &None, vb.pp("qkv"))?,
            proj: mistralrs_quant::linear(dim, dim, &None, vb.pp("proj"))?,
            num_heads,
            head_dim: dim / num_heads,
        })
    }
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        rotary_pos_emb: &Tensor,
    ) -> Result<Tensor> {
        let seq_len = xs.dim(0)?;
        let (mut q, mut k, mut v) = {
            let qkv = self
                .qkv
                .forward(&xs.unsqueeze(0)?)?
                .reshape((seq_len, 3, self.num_heads, ()))?
                .permute((1, 0, 2, 3))?
                .chunk(3, 0)?;
            (qkv[0].squeeze(0)?, qkv[1].squeeze(0)?, qkv[2].squeeze(0)?)
        };

        q = apply_rotary_pos_emb_vision(&q.unsqueeze(0)?, rotary_pos_emb)?
            .squeeze(0)?
            .to_dtype(q.dtype())?;
        k = apply_rotary_pos_emb_vision(&k.unsqueeze(0)?, rotary_pos_emb)?
            .squeeze(0)?
            .to_dtype(q.dtype())?;

        q = q.transpose(0, 1)?.contiguous()?;
        k = k.transpose(0, 1)?.contiguous()?;
        v = v.transpose(0, 1)?.contiguous()?;

        let att = {
            let mut att =
                (MatMul.matmul(&q, &k.transpose(1, 2)?)? / (self.head_dim as f64).sqrt())?;
            att = match attention_mask {
                Some(m) => att.broadcast_add(m)?,
                None => att,
            };
            att = candle_nn::ops::softmax_last_dim(&att)?;
            MatMul
                .matmul(&att, &v)?
                .transpose(0, 1)?
                .reshape((seq_len, ()))?
                .to_dtype(xs.dtype())?
        };

        self.proj.forward(&att.unsqueeze(0)?)?.squeeze(0)
    }
}

// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L418
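// Pre-norm transformer block: x + attn(norm1(x)), followed by x + mlp(norm2(x)).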
struct VisionBlock {
    norm1: RmsNorm,
    norm2: RmsNorm,
    mlp: VisionMlp,
    attn: VisionAttention,
}

impl VisionBlock {
    fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let norm1 = RmsNorm::new(cfg.hidden_size, 1e-6, vb.pp("norm1"))?;
        let norm2 = RmsNorm::new(cfg.hidden_size, 1e-6, vb.pp("norm2"))?;

        let mlp = VisionMlp::new(
            cfg.hidden_size,
            cfg.intermediate_size,
            cfg.hidden_act,
            vb.pp("mlp"),
            comm,
        )?;
        let attn = VisionAttention::new(cfg.hidden_size, cfg.num_heads, vb.pp("attn"))?;

        Ok(Self {
            norm1,
            norm2,
            mlp,
            attn,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        rotary_pos_emb: &Tensor,
    ) -> Result<Tensor> {
        let xs = (xs
            + self
                .attn
                .forward(&self.norm1.forward(xs)?, attention_mask, rotary_pos_emb)?)?;
        &xs + self.mlp.forward(&self.norm2.forward(&xs)?)?
    }
}

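// Merges each group of spatial_merge_size^2 neighboring patch embeddings and projects
// it to `dim` via RMSNorm -> Linear -> GELU -> Linear.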
struct PatchMerger {
    ln_q: RmsNorm,
    mlp0: Linear,
    mlp2: Linear,
    out_hidden_size: usize,
}

impl PatchMerger {
    pub fn new(
        dim: usize,
        context_dim: usize,
        spatial_merge_size: usize,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let out_hidden_size = context_dim * spatial_merge_size.pow(2);
        let mlp0 = layers::linear(out_hidden_size, out_hidden_size, vb.pp("mlp.0"))?;
        let mlp2 = layers::linear(out_hidden_size, dim, vb.pp("mlp.2"))?;
        Ok(Self {
            ln_q: RmsNorm::new(context_dim, 1e-6, vb.pp("ln_q"))?,
            mlp0,
            mlp2,
            out_hidden_size,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.unsqueeze(0)?
            .apply(&self.ln_q)?
            .reshape(((), self.out_hidden_size))?
            .apply(&self.mlp0)?
            .gelu()?
            .apply(&self.mlp2)?
            .squeeze(0)
    }
}

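// Rotary frequency table: `inv_freq` holds 1 / THETA^(2i / dim); `make_embeds` returns
// the outer product of positions 0..seqlen with those inverse frequencies.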
struct VisionRotaryEmbedding {
    inv_freq: Tensor,
}

impl VisionRotaryEmbedding {
    const THETA: f32 = 10000.;

    fn new(dim: usize, device: &Device) -> Result<Self> {
        let inv_freq = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / Self::THETA.powf(i as f32 / dim as f32))
            .collect::<Vec<_>>();
        let inv_freq_len = inv_freq.len();
        Ok(Self {
            inv_freq: Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?,
        })
    }

    fn make_embeds(&self, seqlen: usize) -> Result<Tensor> {
        let seq =
            Tensor::arange(0f32, seqlen as f32, self.inv_freq.device())?.unsqueeze(D::Minus1)?;
        seq.broadcast_matmul(&self.inv_freq)
    }
}

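// The Qwen2.5-VL vision tower: patch embedding, a stack of transformer blocks (those
// listed in `fullatt_block_indices` attend over the whole sequence, the rest attend
// within windows), and a final patch merger.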
pub struct Qwen2_5VLVisionModel {
    blocks: Vec<VisionBlock>,
    patch_merger: PatchMerger,
    patch_embed: PatchEmbed,
    rotary_pos_emb: VisionRotaryEmbedding,
    spatial_merge_size: usize,
    spatial_merge_unit: usize,
    window_size: usize,
    patch_size: usize,
    fullatt_block_indices: Vec<usize>,
}

impl Qwen2_5VLVisionModel {
    pub fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut blocks = Vec::new();
        for i in 0..cfg.depth {
            blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")), comm)?);
        }

        let patch_merger = PatchMerger::new(
            cfg.out_hidden_size,
            cfg.hidden_size,
            cfg.spatial_merge_size,
            vb.pp("merger"),
        )?;

        let patch_embed = PatchEmbed::new(cfg, vb.pp("patch_embed"))?;

        let head_dim = cfg.hidden_size / cfg.num_heads;
        let rotary_pos_emb = VisionRotaryEmbedding::new(head_dim / 2, vb.device())?;

        Ok(Self {
            blocks,
            patch_embed,
            patch_merger,
            rotary_pos_emb,
            spatial_merge_size: cfg.spatial_merge_size,
            spatial_merge_unit: cfg.spatial_merge_size * cfg.spatial_merge_size,
            window_size: cfg.window_size,
            patch_size: cfg.patch_size,
            fullatt_block_indices: cfg.fullatt_block_indexes.clone(),
        })
    }

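    // Builds 2D rotary position ids for every (t, h, w) grid: height and width indices
    // are interleaved in spatial-merge order, used to gather rows of the frequency table,
    // and flattened into a (seq_len, head_dim / 2) embedding.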
    fn rot_pos_emb(&self, grid_thw: &Tensor, device: &Device) -> Result<Tensor> {
        let mut pos_ids = Vec::new();
        for i_thw in grid_thw.to_vec2::<u32>()? {
            let (t, h, w) = (i_thw[0], i_thw[1], i_thw[2]);
            let mut hpos_ids = Tensor::arange(0, h, device)?
                .unsqueeze(1)?
                .repeat((1, w as usize))?;
            hpos_ids = hpos_ids.reshape((
                h as usize / self.spatial_merge_size,
                self.spatial_merge_size,
                w as usize / self.spatial_merge_size,
                self.spatial_merge_size,
            ))?;
            hpos_ids = hpos_ids.permute((0, 2, 1, 3))?;
            hpos_ids = hpos_ids.flatten_all()?;

            let mut wpos_ids = Tensor::arange(0, w, device)?
                .unsqueeze(0)?
                .repeat((h as usize, 1))?;
            wpos_ids = wpos_ids.reshape((
                h as usize / self.spatial_merge_size,
                self.spatial_merge_size,
                w as usize / self.spatial_merge_size,
                self.spatial_merge_size,
            ))?;
            wpos_ids = wpos_ids.permute((0, 2, 1, 3))?;
            wpos_ids = wpos_ids.flatten_all()?;

            pos_ids.push(Tensor::stack(&[hpos_ids, wpos_ids], D::Minus1)?.repeat((t as usize, 1))?);
        }
        let pos_ids = Tensor::cat(&pos_ids, 0)?;
        let max_grid_size = grid_thw.i((.., 1..))?.max(0)?.max(0)?.to_scalar::<u32>()?;
        let rotary_pos_emb_full = self.rotary_pos_emb.make_embeds(max_grid_size as usize)?;

        assert_eq!(pos_ids.rank(), 2);
        rotary_pos_emb_full
            .index_select(&pos_ids.flatten_all()?, 0)?
            .reshape((pos_ids.dim(0)?, pos_ids.dim(1)?, ()))?
            .flatten_from(1)
    }

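    // Computes the permutation that regroups patches into attention windows, together
    // with the cumulative window sequence lengths. Each (t, h, w) grid is padded to a
    // multiple of the merger window size, reordered window by window, and the padding
    // entries are dropped again.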
    fn get_window_index(&self, grid_thw: &Tensor, device: &Device) -> Result<(Tensor, Vec<i64>)> {
        const PADDING_VALUE: i32 = -100;
        let mut window_index = Vec::new();
        let mut cu_window_seqlens = vec![0];
        let mut window_index_id = 0;
        let vit_merger_window_size = self.window_size / self.spatial_merge_size / self.patch_size;

        for i_thw in grid_thw.to_vec2::<u32>()? {
            let (t, h, w) = (i_thw[0] as usize, i_thw[1] as usize, i_thw[2] as usize);
            let llm_grid_h = h / self.spatial_merge_size;
            let llm_grid_w = w / self.spatial_merge_size;
            let index = Tensor::arange(0i32, (t * llm_grid_h * llm_grid_w) as i32, &Device::Cpu)?
                .reshape((t, llm_grid_h, llm_grid_w))?;
            let pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size;
            let pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size;
            let num_windows_h = (llm_grid_h + pad_h) / vit_merger_window_size;
            let num_windows_w = (llm_grid_w + pad_w) / vit_merger_window_size;
            let index_padded = {
                let h = Tensor::full(PADDING_VALUE, (t, pad_h, llm_grid_w), &Device::Cpu)?;
                let w = Tensor::full(PADDING_VALUE, (t, pad_h + llm_grid_h, pad_w), &Device::Cpu)?;
                let mut index = Tensor::cat(&[index, h], D::Minus2)?;
                index = Tensor::cat(&[index, w], D::Minus1)?;
                index = index.reshape((
                    t,
                    num_windows_h,
                    vit_merger_window_size,
                    num_windows_w,
                    vit_merger_window_size,
                ))?;
                index = index.permute((0, 1, 3, 2, 4))?.reshape((
                    t,
                    num_windows_h * num_windows_w,
                    vit_merger_window_size,
                    vit_merger_window_size,
                ))?;
                index
            };
            let seqlens = index_padded
                .ne(PADDING_VALUE)?
                .to_dtype(index_padded.dtype())?
                .sum((2, 3))?
                .flatten_all()?;
            let index_new = index_padded
                .flatten_all()?
                .to_vec1::<i32>()?
                .into_iter()
                .filter(|x| *x != PADDING_VALUE)
                .collect::<Vec<_>>();
            window_index.push(Tensor::new(
                index_new
                    .iter()
                    .map(|x| x + window_index_id)
                    .collect::<Vec<_>>(),
                device,
            )?);
            let cu_seqlens_tmp = ((seqlens
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(seqlens.dtype())?
                * self.spatial_merge_unit as f64)?
                + cu_window_seqlens[cu_window_seqlens.len() - 1] as f64)?;
            cu_window_seqlens.extend(cu_seqlens_tmp.to_vec1::<i64>()?);
            window_index_id += (t * llm_grid_h * llm_grid_w) as i32;
        }

        Ok((Tensor::cat(&window_index, 0)?, cu_window_seqlens))
    }

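    // Full vision forward pass: patch-embed the pixels, reorder patches into window
    // order, run every block with either the full or the windowed block-diagonal mask,
    // merge patches, and finally undo the window reordering.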
    pub fn forward(&self, xs: &Tensor, grid_thw: &Tensor) -> Result<Tensor> {
        let xs = self
            .patch_embed
            .forward(&xs.to_dtype(self.patch_merger.mlp0.weight().dtype())?)?;
        let rotary_pos_emb = self.rot_pos_emb(grid_thw, xs.device())?;
        let (window_index, mut cu_window_seqlens) = self.get_window_index(grid_thw, xs.device())?;
        cu_window_seqlens.dedup();

        let seq_len = xs.dims2()?.0;
        let mut xs = xs.reshape((
            seq_len / self.spatial_merge_unit,
            self.spatial_merge_unit,
            (),
        ))?;
        xs = xs.index_select(&window_index, 0)?;
        xs = xs.reshape((seq_len, ()))?;
        let mut rotary_pos_emb = rotary_pos_emb.reshape((
            seq_len / self.spatial_merge_unit,
            self.spatial_merge_unit,
            (),
        ))?;
        rotary_pos_emb = rotary_pos_emb.index_select(&window_index, 0)?;
        rotary_pos_emb = rotary_pos_emb.reshape((seq_len, ()))?;
        rotary_pos_emb = Tensor::cat(&[&rotary_pos_emb; 2], D::Minus1)?;
        rotary_pos_emb = rotary_pos_emb.to_dtype(xs.dtype())?;

        let grid_thw = grid_thw.to_device(&Device::Cpu)?;
        let cu_seqlens = (grid_thw.i((.., 1))? * grid_thw.i((.., 2))?)?
            .repeat_interleave_flat(grid_thw.i((.., 0))?.to_vec1::<u32>()?)?
            .to_dtype(DType::F32)?
            .cumsum(0)?
            .to_dtype(DType::U32)?
            .pad_with_zeros(0, 1, 0)?
            .to_vec1::<u32>()?;

        let seq_len = xs.dim(0)?;
        let attention_mask_full = match &cu_seqlens[..] {
            &[0, len] if len == seq_len as u32 => None,
            cu_seqlens => {
                let mut attention_mask =
                    Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
                        .to_dtype(xs.dtype())?;
                for i in 1..cu_seqlens.len() {
                    let a = cu_seqlens[i - 1] as usize;
                    let b = cu_seqlens[i] as usize;
                    attention_mask = attention_mask.slice_assign(
                        &[&.., &(a..b), &(a..b)],
                        &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
                    )?;
                }
                Some(attention_mask)
            }
        };
        let attention_mask_window = match &cu_window_seqlens[..] {
            &[0, len] if len == seq_len as i64 => None,
            cu_seqlens => {
                let mut attention_mask =
                    Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
                        .to_dtype(xs.dtype())?;
                for i in 1..cu_seqlens.len() {
                    let a = cu_seqlens[i - 1] as usize;
                    let b = cu_seqlens[i] as usize;
                    attention_mask = attention_mask.slice_assign(
                        &[&.., &(a..b), &(a..b)],
                        &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
                    )?;
                }
                Some(attention_mask)
            }
        };

        for (i, blk) in self.blocks.iter().enumerate() {
            let attention_mask = if self.fullatt_block_indices.contains(&i) {
                attention_mask_full.as_ref()
            } else {
                attention_mask_window.as_ref()
            };
            xs = blk.forward(&xs, attention_mask, &rotary_pos_emb)?;
        }

        xs = self.patch_merger.forward(&xs)?;
        let reverse_indices = window_index.arg_sort_last_dim(true)?;
        xs.index_select(&reverse_indices, 0)
    }
}