#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use text::Qwen2VLTextModel;
use vision::Qwen2VLVisionModel;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::CausalMasker,
    layers_masker::{masked_fill, PastKvLenCache},
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
    },
};

mod config;
mod inputs_processor;
mod text;
mod vision;

pub(crate) use config::Config;
pub(crate) use inputs_processor::Qwen2VLProcessor;

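/// Qwen2-VL multimodal model: a Qwen2 text decoder paired with a ViT-style
/// vision tower. Vision features are written into the token embeddings at the
/// image/video placeholder positions, and positions are encoded with the
/// multimodal (temporal/height/width) RoPE indices built by `get_rope_index`.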
pub struct Qwen2VLModel {
    text: Qwen2VLTextModel,
    vision: Qwen2VLVisionModel,
    spatial_merge_size: usize,
    image_token_id: u32,
    video_token_id: u32,
}

impl Qwen2VLModel {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if cfg.use_sliding_window {
            candle_core::bail!("Sliding window is unsupported for now!");
        }
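        // Some checkpoints store the vision tower under `vision_tower`, others
        // under `visual`; pick whichever prefix is actually present.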
        let vision_vb = if vb.contains_tensor("vision_tower.patch_embed.proj.weight") {
            vb.pp("vision_tower")
        } else {
            vb.pp("visual")
        };
        let vision = Qwen2VLVisionModel::new(
            &cfg.vision_config,
            vision_vb.set_device(normal_loading_metadata.real_device.clone()),
            &normal_loading_metadata.mapper.get_comm_for(0)?,
        )?;
        let text = Qwen2VLTextModel::new(
            cfg,
            vb.clone(),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        Ok(Self {
            text,
            vision,
            spatial_merge_size: cfg.vision_config.spatial_merge_size,
            image_token_id: cfg.image_token_id,
            video_token_id: cfg.video_token_id,
        })
    }

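    /// Compute multimodal RoPE indices, broadly following the reference
    /// Qwen2-VL `get_rope_index` logic: every token gets a (temporal, height,
    /// width) position. Text tokens share the same value on all three axes,
    /// while image/video tokens take their coordinates from the corresponding
    /// `*_grid_thw` entry (spatially merged). Returns `position_ids` of shape
    /// `(3, batch, seq)` and per-sequence `mrope_position_deltas` of shape
    /// `(batch, 1)` used to continue the position count during decoding.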
    #[allow(clippy::too_many_arguments)]
    fn get_rope_index(
        &self,
        input_ids: &Tensor,
        image_grid_thw: Option<&Tensor>,
        video_grid_thw: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        attention_mask_indices: Option<&Tensor>,
        input_ids_searching: Vec<Vec<u32>>,
        image_nums: Vec<usize>,
        video_nums: Vec<usize>,
    ) -> Result<(Tensor, Tensor)> {
        if image_grid_thw.is_some() || video_grid_thw.is_some() {
            let total_input_ids = input_ids.clone();
            let mut position_ids = Tensor::zeros(
                (3, input_ids.dim(0)?, input_ids.dim(1)?),
                DType::I64,
                input_ids.device(),
            )?;
            let mut mrope_position_deltas = Vec::new();

            let mut image_index = 0;
            let mut video_index = 0;
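            // Walk each sequence in the batch, alternating between runs of text
            // tokens and image/video segments in the order they appear.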
            for (i, mut input_ids) in total_input_ids
                .chunk(input_ids.dim(0)?, 0)?
                .into_iter()
                .enumerate()
            {
                if let Some(attention_mask_indices) = attention_mask_indices {
                    input_ids = input_ids
                        .i(i)?
                        .to_dtype(DType::F32)?
                        .index_select(&attention_mask_indices.squeeze(0)?, 0)?
                        .to_dtype(input_ids.dtype())?;
                }
                let image_nums = image_nums[i];
                let vision_nums = video_nums[i];

                let mut llm_pos_ids: Vec<Tensor> = Vec::new();
                let mut max_last_llm_pos_ids = None;
                let mut max_llm_pos_ids = 0;
                let mut st = 0;
                let (mut remain_images, mut remain_videos) = (image_nums, vision_nums);
                for _ in 0..(image_nums + vision_nums) {
                    let ed_image = if input_ids_searching[i].contains(&self.image_token_id)
                        && remain_images > 0
                    {
                        input_ids_searching[i][st..]
                            .iter()
                            .position(|&t| t == self.image_token_id)
                            .unwrap()
                            + st
                    } else {
                        input_ids.dim(0)? + 1
                    };
                    let ed_video = if input_ids_searching[i].contains(&self.video_token_id)
                        && remain_videos > 0
                    {
                        input_ids_searching[i][st..]
                            .iter()
                            .position(|&t| t == self.video_token_id)
                            .unwrap()
                            + st
                    } else {
                        input_ids.dim(0)? + 1
                    };
                    let (ed, llm_grid_t, h, w) = if ed_image < ed_video {
                        let t = image_grid_thw.as_ref().unwrap().i((image_index, 0))?;
                        let h = image_grid_thw.as_ref().unwrap().i((image_index, 1))?;
                        let w = image_grid_thw.as_ref().unwrap().i((image_index, 2))?;
                        image_index += 1;
                        remain_images -= 1;
                        (
                            ed_image,
                            t.to_scalar::<u32>()?,
                            h.to_scalar::<u32>()?,
                            w.to_scalar::<u32>()?,
                        )
                    } else {
                        let t = video_grid_thw.as_ref().unwrap().i((video_index, 0))?;
                        let h = video_grid_thw.as_ref().unwrap().i((video_index, 1))?;
                        let w = video_grid_thw.as_ref().unwrap().i((video_index, 2))?;
                        video_index += 1;
                        remain_videos -= 1;
                        (
                            ed_video,
                            t.to_scalar::<u32>()?,
                            h.to_scalar::<u32>()?,
                            w.to_scalar::<u32>()?,
                        )
                    };
                    let llm_grid_h = h / self.spatial_merge_size as u32;
                    let llm_grid_w = w / self.spatial_merge_size as u32;
                    let text_len = ed - st;

                    let st_idx = max_last_llm_pos_ids.unwrap_or(0);
                    max_llm_pos_ids = max_llm_pos_ids.max(text_len as i64 + st_idx);
                    llm_pos_ids.push(
                        Tensor::arange(st_idx, text_len as i64 + st_idx, input_ids.device())?
                            .unsqueeze(0)?
                            .repeat((3, 1))?,
                    );

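                    // Vision segment: build temporal/height/width coordinate
                    // grids (h/w already divided by the spatial merge size) and
                    // offset them by the running text position.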
                    let t_idx = Tensor::arange(0, llm_grid_t as i64, input_ids.device())?
                        .reshape(((), 1))?
                        .repeat((1, llm_grid_h as usize * llm_grid_w as usize))?
                        .flatten_all()?;
                    let h_idx = Tensor::arange(0, llm_grid_h as i64, input_ids.device())?
                        .reshape((1, (), 1))?
                        .repeat((llm_grid_t as usize, 1, llm_grid_w as usize))?
                        .flatten_all()?;
                    let w_idx = Tensor::arange(0, llm_grid_w as i64, input_ids.device())?
                        .reshape((1, 1, ()))?
                        .repeat((llm_grid_t as usize, llm_grid_h as usize, 1))?
                        .flatten_all()?;
                    max_last_llm_pos_ids = Some(
                        *[llm_grid_t, llm_grid_h, llm_grid_w].iter().max().unwrap() as i64
                            + (text_len as i64 + st_idx),
                    );
                    max_llm_pos_ids = max_llm_pos_ids.max(max_last_llm_pos_ids.unwrap());
                    llm_pos_ids.push(
                        (Tensor::stack(&[t_idx, h_idx, w_idx], 0)?.to_dtype(DType::F32)?
                            + (text_len + st_idx as usize) as f64)?
                            .to_dtype(DType::I64)?,
                    );
                    st = ed + (llm_grid_t * llm_grid_h * llm_grid_w) as usize;
                }

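                // Any text tokens after the last vision segment keep counting
                // from the running offset on all three axes.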
                if st < input_ids.dim(0)? {
                    let st_idx = max_last_llm_pos_ids.unwrap_or(0);
                    let text_len = (input_ids.dim(0)? - st) as u32;
                    max_llm_pos_ids = max_llm_pos_ids.max(text_len as i64 + st_idx);
                    llm_pos_ids.push(
                        Tensor::arange(st_idx, text_len as i64 + st_idx, input_ids.device())?
                            .reshape((1, ()))?
                            .repeat((3, 1))?,
                    );
                }

                let llm_positions = Tensor::cat(&llm_pos_ids, 1)?.reshape((3, ()))?;
                let positions_mask = attention_mask
                    .as_ref()
                    .unwrap()
                    .i(i)?
                    .eq(1f64)?
                    .unsqueeze(0)?
                    .repeat((3, 1))?;

                position_ids = position_ids.slice_assign(
                    &[0..position_ids.dim(0)?, i..i + 1, 0..position_ids.dim(2)?],
                    &positions_mask
                        .where_cond(&llm_positions, &position_ids.i((.., i, ..))?)?
                        .unsqueeze(1)?,
                )?;
                mrope_position_deltas
                    .push(max_llm_pos_ids + 1 - total_input_ids.i(i)?.dim(0)? as i64);
            }
            let mrope_position_deltas_len = mrope_position_deltas.len();
            let mrope_position_deltas = Tensor::from_vec(
                mrope_position_deltas,
                (mrope_position_deltas_len,),
                input_ids.device(),
            )?
            .unsqueeze(1)?;
            Ok((position_ids, mrope_position_deltas))
        } else if let Some(attention_mask) = attention_mask {
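            // Text-only input with an attention mask: positions are the cumulative
            // sum of the mask minus one (masked-out positions set to 1), replicated
            // across the three M-RoPE axes.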
            let position_ids = (attention_mask.to_dtype(DType::F32)?.cumsum(D::Minus1)? - 1f64)?;
            let position_ids = masked_fill(&position_ids, &attention_mask.eq(0f64)?, 1i64)?;
            let position_ids = position_ids.unsqueeze(0)?.repeat((3, 1, 1))?;

            let max_position_ids = position_ids.max(0)?.max_keepdim(D::Minus1)?;
            let mrope_position_deltas =
                ((max_position_ids + 1.)? - attention_mask.dim(D::Minus1)? as f64)?;

            Ok((
                position_ids.to_dtype(DType::I64)?,
                mrope_position_deltas.to_dtype(DType::I64)?,
            ))
        } else {
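            // No mask at all: plain 0..seq_len positions on every axis, zero deltas.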
            let position_ids = Tensor::arange(0i64, input_ids.dim(1)? as i64, input_ids.device())?
                .reshape((1, 1, ()))?
                .repeat((3, input_ids.dim(0)?, 1))?;
            let mrope_position_deltas =
                Tensor::zeros((input_ids.dim(0)?, 1), DType::I64, input_ids.device())?;

            Ok((position_ids, mrope_position_deltas))
        }
    }

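    /// Run the full multimodal forward pass: build the causal mask, embed the
    /// tokens, splice image/video embeddings into their placeholder ranges,
    /// compute M-RoPE position ids (from the full prompt during prefill, or
    /// from `seqlen_offsets` plus the deltas during decoding), and run the
    /// text decoder on the merged embeddings.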
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        pixel_values: Option<Tensor>,
        pixel_values_videos: Option<Tensor>,
        image_grid_thw: Option<Tensor>,
        video_grid_thw: Option<Tensor>,
        seqlens: Vec<usize>,
        continuous_img_pad: Vec<Vec<(usize, usize)>>,
        continuous_vid_pad: Vec<Vec<(usize, usize)>>,
        input_ids_searching: Vec<Vec<u32>>,
        image_nums: Vec<usize>,
        video_nums: Vec<usize>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &seqlen_offsets as &dyn PastKvLenCache,
            self.text.cfg.sliding_window,
            self.text.dtype,
            self.text.cfg.num_attn_heads,
        )?;

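        // Embed the tokens and, if pixel data is present, overwrite the
        // image/video placeholder positions with the projected vision features.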
        let input_embeds = if pixel_values.is_some() || pixel_values_videos.is_some() {
            let mut xs = self.text.embed_tokens(input_ids)?;

            if let Some(pixel_values) = pixel_values {
                let image_embeds = self
                    .vision
                    .forward(
                        &pixel_values,
                        image_grid_thw
                            .as_ref()
                            .context("pixel_values require image_grid_thw")?,
                    )?
                    .to_dtype(self.text.dtype)?;

                for (batch, batch_ids) in continuous_img_pad.into_iter().enumerate() {
                    let mut last_end = 0;
                    for (start, end) in batch_ids {
                        xs = xs.slice_assign(
                            &[batch..batch + 1, start..end, 0..xs.dim(2)?],
                            &image_embeds
                                .i((last_end..last_end + (end - start), ..))?
                                .unsqueeze(0)?,
                        )?;
                        last_end = end - start;
                    }
                }
            }

            if let Some(pixel_values_videos) = pixel_values_videos {
                let video_embeds = self.vision.forward(
                    &pixel_values_videos,
                    video_grid_thw
                        .as_ref()
                        .context("pixel_values_videos require video_grid_thw")?,
                )?;

                for (batch, batch_ids) in continuous_vid_pad.into_iter().enumerate() {
                    let mut last_end = 0;
                    for (start, end) in batch_ids {
                        xs = xs.slice_assign(
                            &[batch..batch + 1, start..end, 0..xs.dim(2)?],
                            &video_embeds
                                .i((last_end..last_end + (end - start), ..))?
                                .unsqueeze(0)?,
                        )?;
                        last_end = end - start;
                    }
                }
            }

            xs
        } else {
            self.text.embed_tokens(input_ids)?
        };

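        // Rebuild a full-length attention mask plus per-sequence index ranges so
        // `get_rope_index` can locate each sequence's tokens inside the (possibly
        // padded) full prompt.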
        let mut ropeidx_attn_mask_bs = Vec::new();
        let max_seqlens = *seqlens.iter().max().unwrap();
        for len in &seqlens {
            ropeidx_attn_mask_bs.push(Tensor::new(
                [vec![1f32; *len], vec![0f32; max_seqlens - len]].concat(),
                input_ids.device(),
            )?);
        }
        let ropeidx_attn_mask = Tensor::stack(&ropeidx_attn_mask_bs, 0)?;
        let mut ropeidx_attn_mask_indices_bs = Vec::new();
        for (len, offset) in seqlens.iter().zip(seqlen_offsets) {
            ropeidx_attn_mask_indices_bs.push(Tensor::from_vec(
                (*offset as i64..(*len as i64 + *offset as i64)).collect(),
                (*len,),
                input_ids.device(),
            )?);
        }
        let ropeidx_attn_mask_indices = Tensor::stack(&ropeidx_attn_mask_indices_bs, 0)?;

        let ropeidx_input_ids = if attention_mask.is_some() {
            input_ids
        } else {
            input_ids_full
        };
        let (position_ids, mrope_position_deltas) = self.get_rope_index(
            ropeidx_input_ids,
            image_grid_thw.as_ref(),
            video_grid_thw.as_ref(),
            Some(&ropeidx_attn_mask),
            Some(&ropeidx_attn_mask_indices),
            input_ids_searching,
            image_nums,
            video_nums,
        )?;

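        // Prefill (mask present): use the positions computed above as-is.
        // Decoding: positions are the current sequence offsets shifted by the
        // per-sequence M-RoPE deltas.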
        let position_ids = if attention_mask.is_some() {
            position_ids
        } else {
            let mut position_ids = Tensor::new(
                seqlen_offsets.iter().map(|x| *x as i64).collect::<Vec<_>>(),
                input_ids.device(),
            )?
            .reshape((1, (), 1))?
            .repeat((3, 1, 1))?;

            position_ids = position_ids.broadcast_add(&mrope_position_deltas.unsqueeze(0)?)?;

            position_ids
        };

        let out = self.text.forward_embeds(
            input_embeds,
            attention_mask.as_ref(),
            &position_ids,
            context_lens,
            flash_params,
        )?;
        Ok(out)
    }
}

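/// Extra per-batch inputs threaded from the Qwen2-VL input processor into
/// [`Qwen2VLModel::forward`]: the full (unchunked) token ids, image/video grid
/// shapes, per-sequence lengths, and the contiguous placeholder-token ranges
/// that get overwritten with vision embeddings.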
pub(crate) struct Qwen2VLVisionSpecificArgs {
    input_ids_full: Tensor,
    image_grid_thw: Option<Tensor>,
    video_grid_thw: Option<Tensor>,
    seqlens: Vec<usize>,
    continuous_img_pad: Vec<Vec<(usize, usize)>>,
    continuous_vid_pad: Vec<Vec<(usize, usize)>>,
    input_ids_searching: Vec<Vec<u32>>,
    image_nums: Vec<usize>,
    video_nums: Vec<usize>,
}

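// `VisionModel` glue: unpack the model-specific args and dispatch to
// `Qwen2VLModel::forward`.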
impl VisionModel for Qwen2VLModel {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let Qwen2VLVisionSpecificArgs {
            input_ids_full,
            image_grid_thw,
            video_grid_thw,
            seqlens,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Qwen2VLVisionSpecificArgs`");
        let (pixel_values, pixel_values_video) = match (&image_grid_thw, &video_grid_thw) {
            (Some(_), None) => (pixel_values, None),
            (None, Some(_)) => (None, pixel_values),
            (None, None) => (None, None),
            (Some(_), Some(_)) => {
                candle_core::bail!("Images and videos cannot be provided together.")
            }
        };
        self.forward(
            input_ids,
            &input_ids_full,
            pixel_values,
            pixel_values_video,
            image_grid_thw,
            video_grid_thw,
            seqlens,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
            seqlen_offsets,
            context_lens,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.text.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.text.cache
    }
    fn device(&self) -> &Device {
        &self.text.device
    }
    fn max_seq_len(&self) -> usize {
        self.text.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.text.cfg
    }
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any> {
        assert_eq!(input_ids.dims()[0], 1);
        Box::new(Qwen2VLVisionSpecificArgs {
            input_ids_full: input_ids.clone(),
            image_grid_thw: None,
            video_grid_thw: None,
            seqlens: vec![input_ids.dims()[1]],
            continuous_img_pad: vec![],
            continuous_vid_pad: vec![],
            input_ids_searching: vec![vec![]; input_ids.dims()[0]],
            image_nums: vec![0; input_ids.dims()[0]],
            video_nums: vec![0; input_ids.dims()[0]],
        })
    }
}

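// In-situ quantization (ISQ) hooks simply delegate to the text backbone.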
impl IsqModel for Qwen2VLModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.text.get_layers()
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        self.text.residual_tensors()
    }
}

impl AnyMoeBaseModelMixin for Qwen2VLModel {}