use std::sync::Arc;

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{LayerNorm, Linear, Module};
use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};

use crate::{
    layers::{self, layer_norm, Activation, Conv3dConfig, Conv3dNoBias, MatMul},
    ops::RepeatInterleaveOp,
};

use super::config::VisionConfig;
13
/// 3D-convolution patch embedding: projects raw pixel patch volumes to
/// `embed_dim`-wide vectors (see `PatchEmbed::forward`).
struct PatchEmbed {
    proj: Conv3dNoBias,         // conv over the (temporal, h, w) patch volume
    in_channels: usize,         // input image channels
    patch_size: usize,          // spatial patch edge length
    temporal_patch_size: usize, // frames per temporal patch; `new` enforces 2
    embed_dim: usize,           // output embedding width
}
21
22impl PatchEmbed {
24 fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
25 if cfg.temporal_patch_size != 2 {
26 candle_core::bail!("Only support temporal patch size of 2");
27 }
28 Ok(Self {
29 proj: Conv3dNoBias::new(
30 cfg.in_channels,
31 cfg.embed_dim,
32 [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size],
33 Conv3dConfig {
34 stride: cfg.patch_size,
35 ..Default::default()
36 },
37 vb.pp("proj"),
38 )?,
39 in_channels: cfg.in_channels,
40 patch_size: cfg.patch_size,
41 temporal_patch_size: cfg.temporal_patch_size,
42 embed_dim: cfg.embed_dim,
43 })
44 }
45
46 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
47 let xs = xs.reshape((
48 (),
49 self.in_channels,
50 self.temporal_patch_size,
51 self.patch_size,
52 self.patch_size,
53 ))?;
54 xs.apply(&self.proj)?.reshape(((), self.embed_dim))
55 }
56}
57
/// Two-layer feed-forward block of a vision transformer layer.
struct VisionMlp {
    fc1: Arc<dyn QuantMethod>, // dim -> hidden_dim projection
    fc2: Arc<dyn QuantMethod>, // hidden_dim -> dim projection
    act: Activation,           // nonlinearity applied between fc1 and fc2
}
64
65impl VisionMlp {
66 fn new(
67 dim: usize,
68 hidden_dim: usize,
69 act: Activation,
70 vb: ShardedVarBuilder,
71 comm: &Arc<mistralrs_quant::Comm>,
72 ) -> Result<Self> {
73 Ok(Self {
74 fc1: ColumnParallelLayer::new(dim, hidden_dim, &None, true, comm, vb.pp("fc1"))?,
75 fc2: ColumnParallelLayer::new(hidden_dim, dim, &None, true, comm, vb.pp("fc2"))?,
76 act,
77 })
78 }
79
80 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
81 let fc1 = self.act.forward(&self.fc1.forward(&xs.unsqueeze(0)?)?)?;
82 self.fc2.forward(&fc1)?.squeeze(0)
83 }
84}
85
86fn rotate_half(xs: &Tensor) -> Result<Tensor> {
87 let last_dim = xs.dim(D::Minus1)?;
88 let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
89 let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
90 Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
91}
92
93fn apply_rotary_pos_emb_vision(xs: &Tensor, freqs: &Tensor) -> Result<Tensor> {
94 let cos = freqs.cos()?;
95 let sin = freqs.sin()?;
96
97 xs.broadcast_mul(&cos)? + rotate_half(xs)?.broadcast_mul(&sin)
98}
99
/// Multi-head self-attention over the flattened patch sequence.
struct VisionAttention {
    qkv: Arc<dyn QuantMethod>,  // fused q/k/v projection: dim -> 3 * dim
    proj: Arc<dyn QuantMethod>, // output projection: dim -> dim
    num_heads: usize,           // attention head count
    head_dim: usize,            // dim / num_heads
}
107
impl VisionAttention {
    /// Builds the fused QKV projection and the output projection.
    fn new(dim: usize, num_heads: usize, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            qkv: mistralrs_quant::linear(dim, dim * 3, &None, vb.pp("qkv"))?,
            proj: mistralrs_quant::linear(dim, dim, &None, vb.pp("proj"))?,
            num_heads,
            head_dim: dim / num_heads,
        })
    }
    /// Self-attention over `xs` of shape (seq_len, dim).
    ///
    /// `attention_mask` is an optional additive mask broadcast onto the
    /// (heads, seq, seq) score matrix; `rotary_pos_emb` carries the rotary
    /// angles applied to q and k. Returns (seq_len, dim).
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        rotary_pos_emb: &Tensor,
    ) -> Result<Tensor> {
        let seq_len = xs.dim(0)?;
        let (mut q, mut k, mut v) = {
            // Fused projection -> (seq, 3, heads, head_dim); permute the qkv
            // axis to the front and split into the three tensors.
            let qkv = self
                .qkv
                .forward(&xs.unsqueeze(0)?)?
                .reshape((seq_len, 3, self.num_heads, ()))?
                .permute((1, 0, 2, 3))?
                .chunk(3, 0)?;
            (qkv[0].squeeze(0)?, qkv[1].squeeze(0)?, qkv[2].squeeze(0)?)
        };

        // Apply rotary embedding to q and k; cast back to the original dtype
        // in case the rotation computed in a wider type.
        q = apply_rotary_pos_emb_vision(&q.unsqueeze(0)?, rotary_pos_emb)?
            .squeeze(0)?
            .to_dtype(q.dtype())?;
        k = apply_rotary_pos_emb_vision(&k.unsqueeze(0)?, rotary_pos_emb)?
            .squeeze(0)?
            .to_dtype(q.dtype())?;

        // (seq, heads, head_dim) -> (heads, seq, head_dim) for batched matmul.
        q = q.transpose(0, 1)?.contiguous()?;
        k = k.transpose(0, 1)?.contiguous()?;
        v = v.transpose(0, 1)?.contiguous()?;

        let att = {
            // Scaled dot-product scores: (heads, seq, seq).
            let mut att =
                (MatMul.matmul(&q, &k.transpose(1, 2)?)? / (self.head_dim as f64).sqrt())?;
            att = match attention_mask {
                Some(m) => att.broadcast_add(m)?,
                None => att,
            };
            att = candle_nn::ops::softmax_last_dim(&att)?;
            // Weighted values, back to (seq, heads * head_dim).
            MatMul
                .matmul(&att, &v)?
                .transpose(0, 1)?
                .reshape((seq_len, ()))?
                .to_dtype(xs.dtype())?
        };

        self.proj.forward(&att.unsqueeze(0)?)?.squeeze(0)
    }
}
163
/// One pre-norm transformer layer: attention and MLP, each with a residual.
struct VisionBlock {
    norm1: LayerNorm, // applied before attention
    norm2: LayerNorm, // applied before the MLP
    mlp: VisionMlp,
    attn: VisionAttention,
}
171
172impl VisionBlock {
173 fn new(
174 cfg: &VisionConfig,
175 vb: ShardedVarBuilder,
176 comm: &Arc<mistralrs_quant::Comm>,
177 ) -> Result<Self> {
178 let norm1 = layer_norm(cfg.embed_dim, 1e-6, vb.pp("norm1"))?;
179 let norm2 = layer_norm(cfg.embed_dim, 1e-6, vb.pp("norm2"))?;
180
181 let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
182 let mlp = VisionMlp::new(
183 cfg.embed_dim,
184 mlp_hidden_dim,
185 cfg.hidden_act,
186 vb.pp("mlp"),
187 comm,
188 )?;
189 let attn = VisionAttention::new(cfg.embed_dim, cfg.num_heads, vb.pp("attn"))?;
190
191 Ok(Self {
192 norm1,
193 norm2,
194 mlp,
195 attn,
196 })
197 }
198
199 fn forward(
200 &self,
201 xs: &Tensor,
202 attention_mask: Option<&Tensor>,
203 rotary_pos_emb: &Tensor,
204 ) -> Result<Tensor> {
205 let xs = (xs
206 + self
207 .attn
208 .forward(&self.norm1.forward(xs)?, attention_mask, rotary_pos_emb)?)?;
209 &xs + self.mlp.forward(&self.norm2.forward(&xs)?)?
210 }
211}
212
/// Merges groups of neighboring patch embeddings into single LLM-width tokens
/// via layer norm + a two-layer MLP.
struct PatchMerger {
    ln_q: LayerNorm, // normalizes patch embeddings before merging
    mlp0: Linear,    // hidden_size -> hidden_size
    mlp2: Linear,    // hidden_size -> output dim
    hidden_size: usize, // context_dim * spatial_merge_size^2
}
219
220impl PatchMerger {
221 pub fn new(
222 dim: usize,
223 context_dim: usize,
224 spatial_merge_size: usize,
225 vb: ShardedVarBuilder,
226 ) -> Result<Self> {
227 let hidden_size = context_dim * spatial_merge_size.pow(2);
228 let mlp0 = layers::linear(hidden_size, hidden_size, vb.pp("mlp.0"))?;
229 let mlp2 = layers::linear(hidden_size, dim, vb.pp("mlp.2"))?;
230 Ok(Self {
231 ln_q: layer_norm(context_dim, 1e-6, vb.pp("ln_q"))?,
232 mlp0,
233 mlp2,
234 hidden_size,
235 })
236 }
237
238 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
239 xs.unsqueeze(0)?
240 .apply(&self.ln_q)?
241 .reshape(((), self.hidden_size))?
242 .apply(&self.mlp0)?
243 .gelu()?
244 .apply(&self.mlp2)?
245 .squeeze(0)
246 }
247}
248
/// Precomputed inverse frequencies for the vision rotary position embedding.
struct VisionRotaryEmbedding {
    inv_freq: Tensor, // shape (1, dim / 2): 1 / THETA^(i / dim) for even i
}
252
253impl VisionRotaryEmbedding {
254 const THETA: f32 = 10000.;
255
256 fn new(dim: usize, device: &Device) -> Result<Self> {
257 let inv_freq = (0..dim)
258 .step_by(2)
259 .map(|i| 1f32 / Self::THETA.powf(i as f32 / dim as f32))
260 .collect::<Vec<_>>();
261 let inv_freq_len = inv_freq.len();
262 Ok(Self {
263 inv_freq: Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?,
264 })
265 }
266
267 fn make_embeds(&self, seqlen: usize) -> Result<Tensor> {
268 let seq =
269 Tensor::arange(0f32, seqlen as f32, self.inv_freq.device())?.unsqueeze(D::Minus1)?;
270 seq.broadcast_matmul(&self.inv_freq)
271 }
272}
273
/// Qwen2-VL vision tower: patch embedding, transformer blocks, rotary position
/// embedding, and the final patch merger producing LLM-width tokens.
pub struct Qwen2VLVisionModel {
    blocks: Vec<VisionBlock>,
    patch_merger: PatchMerger,
    patch_embed: PatchEmbed,
    rotary_pos_emb: VisionRotaryEmbedding,
    spatial_merge_size: usize, // edge length of the spatial merge window
}
281
impl Qwen2VLVisionModel {
    /// Builds the vision tower from the config and weight store.
    pub fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut blocks = Vec::new();
        for i in 0..cfg.depth {
            blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")), comm)?);
        }

        let patch_merger = PatchMerger::new(
            cfg.hidden_size,
            cfg.embed_dim,
            cfg.spatial_merge_size,
            vb.pp("merger"),
        )?;

        let patch_embed = PatchEmbed::new(cfg, vb.pp("patch_embed"))?;

        // The rotary table covers half the head dim; `forward` duplicates the
        // angles (repeat (1, 1, 2)) to span the full head dim.
        let head_dim = cfg.embed_dim / cfg.num_heads;
        let rotary_pos_emb = VisionRotaryEmbedding::new(head_dim / 2, vb.device())?;

        Ok(Self {
            blocks,
            patch_embed,
            patch_merger,
            rotary_pos_emb,
            spatial_merge_size: cfg.spatial_merge_size,
        })
    }

    /// Builds per-patch rotary angles from the (t, h, w) grid of each image.
    ///
    /// For every image, row and column ids are reshaped/permuted so that
    /// patches belonging to the same spatial-merge window end up adjacent,
    /// then the precomputed angle table is gathered with those ids.
    fn rot_pos_emb(&self, grid_thw: &Tensor, device: &Device) -> Result<Tensor> {
        let mut pos_ids = Vec::new();
        for i_thw in grid_thw.to_vec2::<u32>()? {
            let (t, h, w) = (i_thw[0], i_thw[1], i_thw[2]);
            // Row index of every patch, repeated across the width.
            let mut hpos_ids = Tensor::arange(0, h, device)?
                .unsqueeze(1)?
                .repeat((1, w as usize))?;
            // Regroup into merge windows and flatten so window members are
            // contiguous in the id sequence.
            hpos_ids = hpos_ids.reshape((
                h as usize / self.spatial_merge_size,
                self.spatial_merge_size,
                w as usize / self.spatial_merge_size,
                self.spatial_merge_size,
            ))?;
            hpos_ids = hpos_ids.permute((0, 2, 1, 3))?;
            hpos_ids = hpos_ids.flatten_all()?;

            // Column index of every patch, repeated across the height, with
            // the same merge-window regrouping.
            let mut wpos_ids = Tensor::arange(0, w, device)?
                .unsqueeze(0)?
                .repeat((h as usize, 1))?;
            wpos_ids = wpos_ids.reshape((
                h as usize / self.spatial_merge_size,
                self.spatial_merge_size,
                w as usize / self.spatial_merge_size,
                self.spatial_merge_size,
            ))?;
            wpos_ids = wpos_ids.permute((0, 2, 1, 3))?;
            wpos_ids = wpos_ids.flatten_all()?;

            // (h*w, 2) pairs of (row, col), tiled once per temporal frame.
            pos_ids.push(Tensor::stack(&[hpos_ids, wpos_ids], D::Minus1)?.repeat((t as usize, 1))?);
        }
        let pos_ids = Tensor::cat(&pos_ids, 0)?;
        // Largest spatial extent across all images sizes the angle table.
        let max_grid_size = grid_thw.i((.., 1..))?.max(0)?.max(0)?.to_scalar::<u32>()?;
        let rotary_pos_emb_full = self.rotary_pos_emb.make_embeds(max_grid_size as usize)?;

        assert_eq!(pos_ids.rank(), 2);
        // Gather angles for each (row, col) id and flatten the pair axis.
        rotary_pos_emb_full
            .index_select(&pos_ids.flatten_all()?, 0)?
            .reshape((pos_ids.dim(0)?, pos_ids.dim(1)?, ()))?
            .flatten_from(1)
    }

    /// Runs the full vision tower.
    ///
    /// `xs`: flattened pixel patches; `grid_thw`: per-image (t, h, w) patch
    /// grid. Returns merged tokens from the patch merger.
    pub fn forward(&self, xs: &Tensor, grid_thw: &Tensor) -> Result<Tensor> {
        // Cast pixels to the model weight dtype before embedding.
        let mut xs = self
            .patch_embed
            .forward(&xs.to_dtype(self.patch_merger.mlp0.weight().dtype())?)?;
        let rotary_pos_emb = self.rot_pos_emb(grid_thw, xs.device())?;
        // Duplicate the half-head-dim angles to cover the full head dim.
        let rotary_pos_emb = rotary_pos_emb
            .unsqueeze(1)?
            .repeat((1, 1, 2))?
            .unsqueeze(0)?
            .to_dtype(xs.dtype())?;

        // Cumulative patch counts per frame: h*w repeated t times per image,
        // cumsum'd, with a leading 0 — the boundaries of each frame's span.
        let grid_thw = grid_thw.to_device(&Device::Cpu)?;
        let cu_seqlens = (grid_thw.i((.., 1))? * grid_thw.i((.., 2))?)?
            .repeat_interleave_flat(grid_thw.i((.., 0))?.to_vec1::<u32>()?)?
            .to_dtype(DType::F32)?
            .cumsum(0)?
            .to_dtype(DType::U32)?
            .pad_with_zeros(0, 1, 0)?
            .to_vec1::<u32>()?;

        let seq_len = xs.dim(0)?;
        let attention_mask = match &cu_seqlens[..] {
            // A single span covering the whole sequence needs no mask.
            &[0, len] if len == seq_len as u32 => None,
            cu_seqlens => {
                // Block-diagonal additive mask: zeros within each frame's
                // span, f32::MIN (-> ~0 after softmax) everywhere else.
                let mut attention_mask =
                    Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
                        .to_dtype(xs.dtype())?;
                for i in 1..cu_seqlens.len() {
                    let a = cu_seqlens[i - 1] as usize;
                    let b = cu_seqlens[i] as usize;
                    attention_mask = attention_mask.slice_assign(
                        &[&.., &(a..b), &(a..b)],
                        &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
                    )?;
                }
                Some(attention_mask)
            }
        };

        for blk in &self.blocks {
            xs = blk.forward(&xs, attention_mask.as_ref(), &rotary_pos_emb)?;
        }

        self.patch_merger.forward(&xs)
    }
}