use std::sync::Arc;

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Linear, Module};
use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};

use crate::{
    layers::{self, Activation, Conv3dConfig, Conv3dNoBias, MatMul, RmsNorm},
    ops::RepeatInterleaveOp,
};

use super::config::VisionConfig;

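/// Embeds raw pixel patches with a bias-free 3D convolution whose kernel spans
/// `temporal_patch_size` frames and a `patch_size` x `patch_size` spatial window.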
struct PatchEmbed {
    proj: Conv3dNoBias,
    in_channels: usize,
    patch_size: usize,
    temporal_patch_size: usize,
    hidden_size: usize,
}

impl PatchEmbed {
    fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        if cfg.temporal_patch_size != 2 {
            candle_core::bail!("Only support temporal patch size of 2");
        }
        Ok(Self {
            proj: Conv3dNoBias::new(
                cfg.in_chans,
                cfg.hidden_size,
                [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size],
                Conv3dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                },
                vb.pp("proj"),
            )?,
            in_channels: cfg.in_chans,
            patch_size: cfg.patch_size,
            temporal_patch_size: cfg.temporal_patch_size,
            hidden_size: cfg.hidden_size,
        })
    }

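    /// Reshapes flattened patches to
    /// `(n_patches, in_channels, temporal_patch_size, patch_size, patch_size)`,
    /// applies the 3D convolution, and returns `(n_patches, hidden_size)`.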
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.reshape((
            (),
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        ))?;
        xs.apply(&self.proj)?.reshape(((), self.hidden_size))
    }
}

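/// Gated MLP used inside each vision block, built from tensor-parallel
/// column/row layers so the projections can be sharded across devices.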
struct VisionMlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act: Activation,
}

impl VisionMlp {
    fn new(
        dim: usize,
        hidden_dim: usize,
        act: Activation,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate_proj: ColumnParallelLayer::new(
                dim,
                hidden_dim,
                &None,
                true,
                comm,
                vb.pp("gate_proj"),
            )?,
            up_proj: ColumnParallelLayer::new(
                dim,
                hidden_dim,
                &None,
                true,
                comm,
                vb.pp("up_proj"),
            )?,
            down_proj: RowParallelLayer::new(
                hidden_dim,
                dim,
                &None,
                true,
                comm,
                vb.pp("down_proj"),
            )?,
            act,
        })
    }

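    /// Computes `down_proj(act(gate_proj(x)) * up_proj(x))`, casting the input to
    /// the quantized activation dtype (if any) and back to the original dtype.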
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .forward(&xs.unsqueeze(0)?)?
            .apply(&self.act)?;
        let rhs = self.up_proj.forward(&xs.unsqueeze(0)?)?;
        let mut res = self.down_proj.forward(&(lhs * rhs)?)?;

        res = res.squeeze(0)?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

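// Rotary embedding helpers: `rotate_half` negates and swaps the two halves of the
// last dimension, and `apply_rotary_pos_emb_vision` computes
// `x * cos + rotate_half(x) * sin`, broadcasting over the head dimension.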
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

fn apply_rotary_pos_emb_vision(xs: &Tensor, freqs: &Tensor) -> Result<Tensor> {
    let cos = freqs.cos()?.unsqueeze(D::Minus2)?.to_dtype(xs.dtype())?;
    let sin = freqs.sin()?.unsqueeze(D::Minus2)?.to_dtype(xs.dtype())?;

    xs.broadcast_mul(&cos)? + rotate_half(xs)?.broadcast_mul(&sin)
}

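/// Multi-head self-attention over the patch sequence with a fused QKV projection;
/// an optional additive mask restricts attention to windows or to each image/video.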
struct VisionAttention {
    qkv: Arc<dyn QuantMethod>,
    proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    head_dim: usize,
}

impl VisionAttention {
    fn new(dim: usize, num_heads: usize, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            qkv: mistralrs_quant::linear(dim, dim * 3, &None, vb.pp("qkv"))?,
            proj: mistralrs_quant::linear(dim, dim, &None, vb.pp("proj"))?,
            num_heads,
            head_dim: dim / num_heads,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        rotary_pos_emb: &Tensor,
    ) -> Result<Tensor> {
        let seq_len = xs.dim(0)?;
        let (mut q, mut k, mut v) = {
            let qkv = self
                .qkv
                .forward(&xs.unsqueeze(0)?)?
                .reshape((seq_len, 3, self.num_heads, ()))?
                .permute((1, 0, 2, 3))?
                .chunk(3, 0)?;
            (qkv[0].squeeze(0)?, qkv[1].squeeze(0)?, qkv[2].squeeze(0)?)
        };

        q = apply_rotary_pos_emb_vision(&q.unsqueeze(0)?, rotary_pos_emb)?
            .squeeze(0)?
            .to_dtype(q.dtype())?;
        k = apply_rotary_pos_emb_vision(&k.unsqueeze(0)?, rotary_pos_emb)?
            .squeeze(0)?
            .to_dtype(q.dtype())?;

        q = q.transpose(0, 1)?.contiguous()?;
        k = k.transpose(0, 1)?.contiguous()?;
        v = v.transpose(0, 1)?.contiguous()?;

        let att = {
            let mut att =
                (MatMul.matmul(&q, &k.transpose(1, 2)?)? / (self.head_dim as f64).sqrt())?;
            att = match attention_mask {
                Some(m) => att.broadcast_add(m)?,
                None => att,
            };
            att = candle_nn::ops::softmax_last_dim(&att)?;
            MatMul
                .matmul(&att, &v)?
                .transpose(0, 1)?
                .reshape((seq_len, ()))?
                .to_dtype(xs.dtype())?
        };

        self.proj.forward(&att.unsqueeze(0)?)?.squeeze(0)
    }
}

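/// Pre-norm transformer block: `x + attn(norm1(x))` followed by `x + mlp(norm2(x))`.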
struct VisionBlock {
    norm1: RmsNorm,
    norm2: RmsNorm,
    mlp: VisionMlp,
    attn: VisionAttention,
}

impl VisionBlock {
    fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let norm1 = RmsNorm::new(cfg.hidden_size, 1e-6, vb.pp("norm1"))?;
        let norm2 = RmsNorm::new(cfg.hidden_size, 1e-6, vb.pp("norm2"))?;

        let mlp = VisionMlp::new(
            cfg.hidden_size,
            cfg.intermediate_size,
            cfg.hidden_act,
            vb.pp("mlp"),
            comm,
        )?;
        let attn = VisionAttention::new(cfg.hidden_size, cfg.num_heads, vb.pp("attn"))?;

        Ok(Self {
            norm1,
            norm2,
            mlp,
            attn,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        rotary_pos_emb: &Tensor,
    ) -> Result<Tensor> {
        let xs = (xs
            + self
                .attn
                .forward(&self.norm1.forward(xs)?, attention_mask, rotary_pos_emb)?)?;
        &xs + self.mlp.forward(&self.norm2.forward(&xs)?)?
    }
}

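/// Merges each `spatial_merge_size`^2 group of patch features into one token and
/// projects it through an RMSNorm and a two-layer GELU MLP to the configured output size.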
struct PatchMerger {
    ln_q: RmsNorm,
    mlp0: Linear,
    mlp2: Linear,
    out_hidden_size: usize,
}

impl PatchMerger {
    pub fn new(
        dim: usize,
        context_dim: usize,
        spatial_merge_size: usize,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let out_hidden_size = context_dim * spatial_merge_size.pow(2);
        let mlp0 = layers::linear(out_hidden_size, out_hidden_size, vb.pp("mlp.0"))?;
        let mlp2 = layers::linear(out_hidden_size, dim, vb.pp("mlp.2"))?;
        Ok(Self {
            ln_q: RmsNorm::new(context_dim, 1e-6, vb.pp("ln_q"))?,
            mlp0,
            mlp2,
            out_hidden_size,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.unsqueeze(0)?
            .apply(&self.ln_q)?
            .reshape(((), self.out_hidden_size))?
            .apply(&self.mlp0)?
            .gelu()?
            .apply(&self.mlp2)?
            .squeeze(0)
    }
}

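/// Inverse-frequency table for rotary embeddings (base theta = 10000), evaluated
/// for a given maximum position via `make_embeds`.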
struct VisionRotaryEmbedding {
    inv_freq: Tensor,
}

impl VisionRotaryEmbedding {
    const THETA: f32 = 10000.;

    fn new(dim: usize, device: &Device) -> Result<Self> {
        let inv_freq = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / Self::THETA.powf(i as f32 / dim as f32))
            .collect::<Vec<_>>();
        let inv_freq_len = inv_freq.len();
        Ok(Self {
            inv_freq: Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?,
        })
    }

    fn make_embeds(&self, seqlen: usize) -> Result<Tensor> {
        let seq =
            Tensor::arange(0f32, seqlen as f32, self.inv_freq.device())?.unsqueeze(D::Minus1)?;
        seq.broadcast_matmul(&self.inv_freq)
    }
}

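/// Qwen2.5-VL vision tower: 3D patch embedding, a stack of transformer blocks that
/// use windowed attention except for the configured full-attention blocks, and a
/// patch merger that produces the final vision tokens.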
pub struct Qwen2_5VLVisionModel {
    blocks: Vec<VisionBlock>,
    patch_merger: PatchMerger,
    patch_embed: PatchEmbed,
    rotary_pos_emb: VisionRotaryEmbedding,
    spatial_merge_size: usize,
    spatial_merge_unit: usize,
    window_size: usize,
    patch_size: usize,
    fullatt_block_indices: Vec<usize>,
}

impl Qwen2_5VLVisionModel {
    pub fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut blocks = Vec::new();
        for i in 0..cfg.depth {
            blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")), comm)?);
        }

        let patch_merger = PatchMerger::new(
            cfg.out_hidden_size,
            cfg.hidden_size,
            cfg.spatial_merge_size,
            vb.pp("merger"),
        )?;

        let patch_embed = PatchEmbed::new(cfg, vb.pp("patch_embed"))?;

        let head_dim = cfg.hidden_size / cfg.num_heads;
        let rotary_pos_emb = VisionRotaryEmbedding::new(head_dim / 2, vb.device())?;

        Ok(Self {
            blocks,
            patch_embed,
            patch_merger,
            rotary_pos_emb,
            spatial_merge_size: cfg.spatial_merge_size,
            spatial_merge_unit: cfg.spatial_merge_size * cfg.spatial_merge_size,
            window_size: cfg.window_size,
            patch_size: cfg.patch_size,
            fullatt_block_indices: cfg.fullatt_block_indexes.clone(),
        })
    }

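    /// Builds per-patch 2D rotary embeddings: height/width position ids are laid
    /// out to match the `spatial_merge_size` grouping and used to index a table of
    /// precomputed frequencies.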
    fn rot_pos_emb(&self, grid_thw: &Tensor, device: &Device) -> Result<Tensor> {
        let mut pos_ids = Vec::new();
        for i_thw in grid_thw.to_vec2::<u32>()? {
            let (t, h, w) = (i_thw[0], i_thw[1], i_thw[2]);
            let mut hpos_ids = Tensor::arange(0, h, device)?
                .unsqueeze(1)?
                .repeat((1, w as usize))?;
            hpos_ids = hpos_ids.reshape((
                h as usize / self.spatial_merge_size,
                self.spatial_merge_size,
                w as usize / self.spatial_merge_size,
                self.spatial_merge_size,
            ))?;
            hpos_ids = hpos_ids.permute((0, 2, 1, 3))?;
            hpos_ids = hpos_ids.flatten_all()?;

            let mut wpos_ids = Tensor::arange(0, w, device)?
                .unsqueeze(0)?
                .repeat((h as usize, 1))?;
            wpos_ids = wpos_ids.reshape((
                h as usize / self.spatial_merge_size,
                self.spatial_merge_size,
                w as usize / self.spatial_merge_size,
                self.spatial_merge_size,
            ))?;
            wpos_ids = wpos_ids.permute((0, 2, 1, 3))?;
            wpos_ids = wpos_ids.flatten_all()?;

            pos_ids.push(Tensor::stack(&[hpos_ids, wpos_ids], D::Minus1)?.repeat((t as usize, 1))?);
        }
        let pos_ids = Tensor::cat(&pos_ids, 0)?;
        let max_grid_size = grid_thw.i((.., 1..))?.max(0)?.max(0)?.to_scalar::<u32>()?;
        let rotary_pos_emb_full = self.rotary_pos_emb.make_embeds(max_grid_size as usize)?;

        assert_eq!(pos_ids.rank(), 2);
        rotary_pos_emb_full
            .index_select(&pos_ids.flatten_all()?, 0)?
            .reshape((pos_ids.dim(0)?, pos_ids.dim(1)?, ()))?
            .flatten_from(1)
    }

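    /// Computes the permutation that groups merged patches into attention windows
    /// of `window_size / spatial_merge_size / patch_size` merged patches per side,
    /// padding each grid with a sentinel so it divides evenly, and returns the
    /// cumulative per-window sequence lengths used for the windowed mask.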
    fn get_window_index(&self, grid_thw: &Tensor, device: &Device) -> Result<(Tensor, Vec<i64>)> {
        const PADDING_VALUE: i32 = -100;
        let mut window_index = Vec::new();
        let mut cu_window_seqlens = vec![0];
        let mut window_index_id = 0;
        let vit_merger_window_size = self.window_size / self.spatial_merge_size / self.patch_size;

        for i_thw in grid_thw.to_vec2::<u32>()? {
            let (t, h, w) = (i_thw[0] as usize, i_thw[1] as usize, i_thw[2] as usize);
            let llm_grid_h = h / self.spatial_merge_size;
            let llm_grid_w = w / self.spatial_merge_size;
            let index = Tensor::arange(0i32, (t * llm_grid_h * llm_grid_w) as i32, &Device::Cpu)?
                .reshape((t, llm_grid_h, llm_grid_w))?;
            let pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size;
            let pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size;
            let num_windows_h = (llm_grid_h + pad_h) / vit_merger_window_size;
            let num_windows_w = (llm_grid_w + pad_w) / vit_merger_window_size;
            let index_padded = {
                let h = Tensor::full(PADDING_VALUE, (t, pad_h, llm_grid_w), &Device::Cpu)?;
                let w = Tensor::full(PADDING_VALUE, (t, pad_h + llm_grid_h, pad_w), &Device::Cpu)?;
                let mut index = Tensor::cat(&[index, h], D::Minus2)?;
                index = Tensor::cat(&[index, w], D::Minus1)?;
                index = index.reshape((
                    t,
                    num_windows_h,
                    vit_merger_window_size,
                    num_windows_w,
                    vit_merger_window_size,
                ))?;
                index = index.permute((0, 1, 3, 2, 4))?.reshape((
                    t,
                    num_windows_h * num_windows_w,
                    vit_merger_window_size,
                    vit_merger_window_size,
                ))?;
                index
            };
            let seqlens = index_padded
                .ne(PADDING_VALUE)?
                .to_dtype(index_padded.dtype())?
                .sum((2, 3))?
                .flatten_all()?;
            let index_new = index_padded
                .flatten_all()?
                .to_vec1::<i32>()?
                .into_iter()
                .filter(|x| *x != PADDING_VALUE)
                .collect::<Vec<_>>();
            window_index.push(Tensor::new(
                index_new
                    .iter()
                    .map(|x| x + window_index_id)
                    .collect::<Vec<_>>(),
                device,
            )?);
            let cu_seqlens_tmp = ((seqlens
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(seqlens.dtype())?
                * self.spatial_merge_unit as f64)?
                + cu_window_seqlens[cu_window_seqlens.len() - 1] as f64)?;
            cu_window_seqlens.extend(cu_seqlens_tmp.to_vec1::<i64>()?);
            window_index_id += (t * llm_grid_h * llm_grid_w) as i32;
        }

        Ok((Tensor::cat(&window_index, 0)?, cu_window_seqlens))
    }

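    /// Runs the vision tower: `xs` holds the flattened pixel patches and `grid_thw`
    /// the (t, h, w) patch grid per image/video. Patches are permuted into windows,
    /// run through the blocks, merged, and finally restored to their original order.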
    pub fn forward(&self, xs: &Tensor, grid_thw: &Tensor) -> Result<Tensor> {
        let xs = self
            .patch_embed
            .forward(&xs.to_dtype(self.patch_merger.mlp0.weight().dtype())?)?;
        let rotary_pos_emb = self.rot_pos_emb(grid_thw, xs.device())?;
        let (window_index, mut cu_window_seqlens) = self.get_window_index(grid_thw, xs.device())?;
        cu_window_seqlens.dedup();

        let seq_len = xs.dims2()?.0;
        let mut xs = xs.reshape((
            seq_len / self.spatial_merge_unit,
            self.spatial_merge_unit,
            (),
        ))?;
        xs = xs.index_select(&window_index, 0)?;
        xs = xs.reshape((seq_len, ()))?;
        let mut rotary_pos_emb = rotary_pos_emb.reshape((
            seq_len / self.spatial_merge_unit,
            self.spatial_merge_unit,
            (),
        ))?;
        rotary_pos_emb = rotary_pos_emb.index_select(&window_index, 0)?;
        rotary_pos_emb = rotary_pos_emb.reshape((seq_len, ()))?;
        rotary_pos_emb = Tensor::cat(&[&rotary_pos_emb; 2], D::Minus1)?;
        rotary_pos_emb = rotary_pos_emb.to_dtype(xs.dtype())?;

        let grid_thw = grid_thw.to_device(&Device::Cpu)?;
        let cu_seqlens = (grid_thw.i((.., 1))? * grid_thw.i((.., 2))?)?
            .repeat_interleave_flat(grid_thw.i((.., 0))?.to_vec1::<u32>()?)?
            .to_dtype(DType::F32)?
            .cumsum(0)?
            .to_dtype(DType::U32)?
            .pad_with_zeros(0, 1, 0)?
            .to_vec1::<u32>()?;

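        // Build block-diagonal additive masks: `attention_mask_full` isolates each
        // image/video, while `attention_mask_window` isolates each attention window.
        // `None` means there is only a single sequence, so no mask is needed.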
        let seq_len = xs.dim(0)?;
        let attention_mask_full = match &cu_seqlens[..] {
            &[0, len] if len == seq_len as u32 => None,
            cu_seqlens => {
                let mut attention_mask =
                    Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
                        .to_dtype(xs.dtype())?;
                for i in 1..cu_seqlens.len() {
                    let a = cu_seqlens[i - 1] as usize;
                    let b = cu_seqlens[i] as usize;
                    attention_mask = attention_mask.slice_assign(
                        &[&.., &(a..b), &(a..b)],
                        &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
                    )?;
                }
                Some(attention_mask)
            }
        };
        let attention_mask_window = match &cu_window_seqlens[..] {
            &[0, len] if len == seq_len as i64 => None,
            cu_seqlens => {
                let mut attention_mask =
                    Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
                        .to_dtype(xs.dtype())?;
                for i in 1..cu_seqlens.len() {
                    let a = cu_seqlens[i - 1] as usize;
                    let b = cu_seqlens[i] as usize;
                    attention_mask = attention_mask.slice_assign(
                        &[&.., &(a..b), &(a..b)],
                        &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
                    )?;
                }
                Some(attention_mask)
            }
        };

        for (i, blk) in self.blocks.iter().enumerate() {
            let attention_mask = if self.fullatt_block_indices.contains(&i) {
                attention_mask_full.as_ref()
            } else {
                attention_mask_window.as_ref()
            };
            xs = blk.forward(&xs, attention_mask, &rotary_pos_emb)?;
        }

        xs = self.patch_merger.forward(&xs)?;
        let reverse_indices = window_index.arg_sort_last_dim(true)?;
        xs.index_select(&reverse_indices, 0)
    }
}