#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use text::Qwen2_5VLTextModel;
use vision::Qwen2_5VLVisionModel;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::CausalMasker,
    layers_masker::{masked_fill, PastKvLenCache},
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
    },
};

mod config;
mod inputs_processor;
mod text;
mod vision;

pub(crate) use config::Config;
pub(crate) use inputs_processor::Qwen2_5VLProcessor;

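/// Qwen2.5-VL vision-language model: a vision transformer encoder
/// (`Qwen2_5VLVisionModel`) whose image/video embeddings are spliced into the
/// Qwen2.5 text decoder (`Qwen2_5VLTextModel`) at placeholder-token positions.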
pub struct Qwen2_5VLModel {
    text: Qwen2_5VLTextModel,
    vision: Qwen2_5VLVisionModel,
    spatial_merge_size: usize,
    image_token_id: u32,
    video_token_id: u32,
}

impl Qwen2_5VLModel {
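    /// Load both towers from `vb`: the vision encoder from the `visual`
    /// prefix, pinned to the real (non-mapped) device, and the text decoder
    /// from the root, which handles its own device mapping.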
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if cfg.use_sliding_window {
            candle_core::bail!("Sliding window is unsupported for now!");
        }
        let vision = Qwen2_5VLVisionModel::new(
            &cfg.vision_config,
            vb.pp("visual")
                .set_device(normal_loading_metadata.real_device.clone()),
            &normal_loading_metadata.mapper.get_comm_for(0)?,
        )?;
        let text = Qwen2_5VLTextModel::new(
            cfg,
            vb.clone(),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        Ok(Self {
            text,
            vision,
            spatial_merge_size: cfg.vision_config.spatial_merge_size,
            image_token_id: cfg.image_token_id,
            video_token_id: cfg.video_token_id,
        })
    }

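    /// Compute M-RoPE (multimodal rotary) position ids. Each token gets three
    /// position ids (temporal, height, width): text tokens advance all three
    /// axes together, while vision tokens index into their `(t, h, w)` grid.
    /// Returns `(3, bs, seq_len)` position ids plus per-sequence `(bs, 1)`
    /// deltas used to offset positions during incremental decoding.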
    #[allow(clippy::too_many_arguments)]
    fn get_rope_index(
        &self,
        input_ids: &Tensor,
        image_grid_thw: Option<&Tensor>,
        video_grid_thw: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        attention_mask_indices: Option<&Tensor>,
        input_ids_searching: Vec<Vec<u32>>,
        image_nums: Vec<usize>,
        video_nums: Vec<usize>,
    ) -> Result<(Tensor, Tensor)> {
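        // Vision inputs present: build interleaved text/vision positions for
        // every sequence in the batch.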
        if image_grid_thw.is_some() || video_grid_thw.is_some() {
            let total_input_ids = input_ids.clone();
            let mut position_ids = Tensor::zeros(
                (3, input_ids.dim(0)?, input_ids.dim(1)?),
                DType::I64,
                input_ids.device(),
            )?;
            let mut mrope_position_deltas = Vec::new();

            let mut image_index = 0;
            let mut video_index = 0;
            for (i, mut input_ids) in total_input_ids
                .chunk(input_ids.dim(0)?, 0)?
                .into_iter()
                .enumerate()
            {
                if let Some(attention_mask_indices) = attention_mask_indices {
                    input_ids = input_ids
                        .i(i)?
                        .to_dtype(DType::F32)?
                        .index_select(&attention_mask_indices.squeeze(0)?, 0)?
                        .to_dtype(input_ids.dtype())?;
                }
                let image_nums = image_nums[i];
                let video_nums = video_nums[i];

                let mut llm_pos_ids: Vec<Tensor> = Vec::new();
                let mut max_last_llm_pos_ids = None;
                let mut max_llm_pos_ids = 0;
                let mut st = 0;
                let (mut remain_images, mut remain_videos) = (image_nums, video_nums);
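                // Consume the sequence span by span: the text run up to the
                // next image/video placeholder, then that vision grid.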
                for _ in 0..(image_nums + video_nums) {
                    let ed_image = if input_ids_searching[i].contains(&self.image_token_id)
                        && remain_images > 0
                    {
                        input_ids_searching[i][st..]
                            .iter()
                            .position(|&t| t == self.image_token_id)
                            .unwrap()
                            + st
                    } else {
                        input_ids.dim(0)? + 1
                    };
                    let ed_video = if input_ids_searching[i].contains(&self.video_token_id)
                        && remain_videos > 0
                    {
                        input_ids_searching[i][st..]
                            .iter()
                            .position(|&t| t == self.video_token_id)
                            .unwrap()
                            + st
                    } else {
                        input_ids.dim(0)? + 1
                    };
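                    // The earlier placeholder decides whether this span's grid
                    // comes from the image list or the video list.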
                    let (ed, llm_grid_t, h, w) = if ed_image < ed_video {
                        let t = image_grid_thw.as_ref().unwrap().i((image_index, 0))?;
                        let h = image_grid_thw.as_ref().unwrap().i((image_index, 1))?;
                        let w = image_grid_thw.as_ref().unwrap().i((image_index, 2))?;
                        image_index += 1;
                        remain_images -= 1;
                        (
                            ed_image,
                            t.to_scalar::<u32>()?,
                            h.to_scalar::<u32>()?,
                            w.to_scalar::<u32>()?,
                        )
                    } else {
                        let t = video_grid_thw.as_ref().unwrap().i((video_index, 0))?;
                        let h = video_grid_thw.as_ref().unwrap().i((video_index, 1))?;
                        let w = video_grid_thw.as_ref().unwrap().i((video_index, 2))?;
                        video_index += 1;
                        remain_videos -= 1;
                        (
                            ed_video,
                            t.to_scalar::<u32>()?,
                            h.to_scalar::<u32>()?,
                            w.to_scalar::<u32>()?,
                        )
                    };
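                    // Text tokens in [st, ed) advance all three axes together;
                    // the spatial dims shrink by the spatial merge factor.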
                    let llm_grid_h = h / self.spatial_merge_size as u32;
                    let llm_grid_w = w / self.spatial_merge_size as u32;
                    let text_len = ed - st;

                    let st_idx = max_last_llm_pos_ids.unwrap_or(0);
                    max_llm_pos_ids = max_llm_pos_ids.max(text_len as i64 + st_idx);
                    llm_pos_ids.push(
                        Tensor::arange(st_idx, text_len as i64 + st_idx, input_ids.device())?
                            .unsqueeze(0)?
                            .repeat((3, 1))?,
                    );

                    let t_idx = Tensor::arange(0, llm_grid_t as i64, input_ids.device())?
                        .reshape(((), 1))?
                        .repeat((1, llm_grid_h as usize * llm_grid_w as usize))?
                        .flatten_all()?;
                    let h_idx = Tensor::arange(0, llm_grid_h as i64, input_ids.device())?
                        .reshape((1, (), 1))?
                        .repeat((llm_grid_t as usize, 1, llm_grid_w as usize))?
                        .flatten_all()?;
                    let w_idx = Tensor::arange(0, llm_grid_w as i64, input_ids.device())?
                        .reshape((1, 1, ()))?
                        .repeat((llm_grid_t as usize, llm_grid_h as usize, 1))?
                        .flatten_all()?;
                    max_last_llm_pos_ids = Some(
                        *[llm_grid_t, llm_grid_h, llm_grid_w].iter().max().unwrap() as i64
                            + (text_len as i64 + st_idx),
                    );
                    max_llm_pos_ids = max_llm_pos_ids.max(max_last_llm_pos_ids.unwrap());
                    llm_pos_ids.push(
                        (Tensor::stack(&[t_idx, h_idx, w_idx], 0)?.to_dtype(DType::F32)?
                            + (text_len + st_idx as usize) as f64)?
                            .to_dtype(DType::I64)?,
                    );
                    st = ed + (llm_grid_t * llm_grid_h * llm_grid_w) as usize;
                }

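                // Any text trailing the last vision span.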
                if st < input_ids.dim(0)? {
                    let st_idx = max_last_llm_pos_ids.unwrap_or(0);
                    let text_len = (input_ids.dim(0)? - st) as u32;
                    max_llm_pos_ids = max_llm_pos_ids.max(text_len as i64 + st_idx);
                    llm_pos_ids.push(
                        Tensor::arange(st_idx, text_len as i64 + st_idx, input_ids.device())?
                            .reshape((1, ()))?
                            .repeat((3, 1))?,
                    );
                }

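                // Scatter this sequence's positions into the batch tensor,
                // keeping existing values wherever the attention mask is 0.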
                let llm_positions = Tensor::cat(&llm_pos_ids, 1)?.reshape((3, ()))?;
                let positions_mask = attention_mask
                    .as_ref()
                    .unwrap()
                    .i(i)?
                    .eq(1f64)?
                    .unsqueeze(0)?
                    .repeat((3, 1))?;

                position_ids = position_ids.slice_assign(
                    &[&.., &i, &..],
                    &positions_mask
                        .where_cond(&llm_positions, &position_ids.i((.., i, ..))?)?
                        .unsqueeze(1)?,
                )?;
                mrope_position_deltas
                    .push(max_llm_pos_ids + 1 - total_input_ids.i(i)?.dim(0)? as i64);
            }
            let mrope_position_deltas_len = mrope_position_deltas.len();
            let mrope_position_deltas = Tensor::from_vec(
                mrope_position_deltas,
                (mrope_position_deltas_len,),
                input_ids.device(),
            )?
            .unsqueeze(1)?;
            Ok((position_ids, mrope_position_deltas))
        } else if let Some(attention_mask) = attention_mask {
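            // Text-only with padding: positions are the cumulative sum of the
            // mask, so padding slots do not advance the position counter.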
            let position_ids = (attention_mask.to_dtype(DType::F32)?.cumsum(D::Minus1)? - 1f64)?;
            let position_ids = masked_fill(&position_ids, &attention_mask.eq(0f64)?, 1i64)?;
            let position_ids = position_ids.unsqueeze(0)?.repeat((3, 1, 1))?;

            let max_position_ids = position_ids.max(0)?.max_keepdim(D::Minus1)?;
            let mrope_position_deltas =
                ((max_position_ids + 1.)? - attention_mask.dim(D::Minus1)? as f64)?;

            Ok((
                position_ids.to_dtype(DType::I64)?,
                mrope_position_deltas.to_dtype(DType::I64)?,
            ))
        } else {
            let position_ids = Tensor::arange(0i64, input_ids.dim(1)? as i64, input_ids.device())?
                .reshape((1, 1, ()))?
                .repeat((3, input_ids.dim(0)?, 1))?;
            let mrope_position_deltas =
                Tensor::zeros((input_ids.dim(0)?, 1), DType::I64, input_ids.device())?;

            Ok((position_ids, mrope_position_deltas))
        }
    }

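    /// Full forward pass: encode any image/video pixels, splice the resulting
    /// embeddings over the placeholder tokens, derive M-RoPE position ids, and
    /// run the text decoder.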
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        pixel_values: Option<Tensor>,
        pixel_values_videos: Option<Tensor>,
        image_grid_thw: Option<Tensor>,
        video_grid_thw: Option<Tensor>,
        seqlens: Vec<usize>,
        continuous_img_pad: Vec<Vec<(usize, usize)>>,
        continuous_vid_pad: Vec<Vec<(usize, usize)>>,
        input_ids_searching: Vec<Vec<u32>>,
        image_nums: Vec<usize>,
        video_nums: Vec<usize>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &seqlen_offsets as &dyn PastKvLenCache,
            self.text.cfg.sliding_window,
            self.text.dtype,
            self.text.cfg.num_attn_heads,
        )?;

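        // Embed the tokens, replacing image/video placeholders with the
        // corresponding vision-tower outputs.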
        let input_embeds = if pixel_values.is_some() || pixel_values_videos.is_some() {
            let mut xs = self.text.embed_tokens(input_ids)?;

            if let Some(pixel_values) = pixel_values {
                let image_embeds = self
                    .vision
                    .forward(
                        &pixel_values,
                        image_grid_thw
                            .as_ref()
                            .context("pixel_values require image_grid_thw")?,
                    )?
                    .to_dtype(self.text.dtype)?;

                for (batch, batch_ids) in continuous_img_pad.into_iter().enumerate() {
                    let mut last_end = 0;
                    for (start, end) in batch_ids {
                        xs = xs.slice_assign(
                            &[&batch, &(start..end), &..],
                            &image_embeds
                                .i((last_end..last_end + (end - start), ..))?
                                .unsqueeze(0)?,
                        )?;
                        // Accumulate so consecutive spans read consecutive
                        // rows of `image_embeds` rather than restarting.
                        last_end += end - start;
                    }
                }
            }

            if let Some(pixel_values_videos) = pixel_values_videos {
                let video_embeds = self.vision.forward(
                    &pixel_values_videos,
                    video_grid_thw
                        .as_ref()
                        .context("pixel_values_videos require video_grid_thw")?,
                )?;

                for (batch, batch_ids) in continuous_vid_pad.into_iter().enumerate() {
                    let mut last_end = 0;
                    for (start, end) in batch_ids {
                        xs = xs.slice_assign(
                            &[&batch, &(start..end), &..],
                            &video_embeds
                                .i((last_end..last_end + (end - start), ..))?
                                .unsqueeze(0)?,
                        )?;
                        last_end += end - start;
                    }
                }
            }

            xs
        } else {
            self.text.embed_tokens(input_ids)?
        };

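        // Synthetic mask and per-sequence index tensors covering the unpadded
        // lengths; these are only consumed by `get_rope_index`.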
        let mut ropeidx_attn_mask_bs = Vec::new();
        let max_seqlens = *seqlens.iter().max().unwrap();
        for len in &seqlens {
            ropeidx_attn_mask_bs.push(Tensor::new(
                [vec![1f32; *len], vec![0f32; max_seqlens - len]].concat(),
                input_ids.device(),
            )?);
        }
        let ropeidx_attn_mask = Tensor::stack(&ropeidx_attn_mask_bs, 0)?;
        let mut ropeidx_attn_mask_indices_bs = Vec::new();
        for len in seqlens {
            ropeidx_attn_mask_indices_bs.push(Tensor::from_vec(
                (0..len as i64).collect(),
                (len,),
                input_ids.device(),
            )?);
        }
        let ropeidx_attn_mask_indices = Tensor::stack(&ropeidx_attn_mask_indices_bs, 0)?;

        let ropeidx_input_ids = if attention_mask.is_some() {
            input_ids
        } else {
            input_ids_full
        };
        let (position_ids, mrope_position_deltas) = self.get_rope_index(
            ropeidx_input_ids,
            image_grid_thw.as_ref(),
            video_grid_thw.as_ref(),
            Some(&ropeidx_attn_mask),
            Some(&ropeidx_attn_mask_indices),
            input_ids_searching,
            image_nums,
            video_nums,
        )?;

        let position_ids = if attention_mask.is_some() {
            position_ids
        } else {
            let mut position_ids = Tensor::new(
                seqlen_offsets.iter().map(|x| *x as i64).collect::<Vec<_>>(),
                input_ids.device(),
            )?
            .reshape((1, (), 1))?
            .repeat((3, 1, 1))?;

            position_ids = position_ids.broadcast_add(&mrope_position_deltas.unsqueeze(0)?)?;

            position_ids
        };

        let out = self.text.forward_embeds(
            input_embeds,
            attention_mask.as_ref(),
            &position_ids,
            context_lens,
            flash_params,
        )?;
        Ok(out)
    }
}

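/// Extra per-batch inputs produced by the Qwen2.5-VL inputs processor and
/// threaded into [`Qwen2_5VLModel::forward`].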
pub(crate) struct Qwen2_5VLVisionSpecificArgs {
    input_ids_full: Tensor,
    image_grid_thw: Option<Tensor>,
    video_grid_thw: Option<Tensor>,
    seqlens: Vec<usize>,
    continuous_img_pad: Vec<Vec<(usize, usize)>>,
    continuous_vid_pad: Vec<Vec<(usize, usize)>>,
    input_ids_searching: Vec<Vec<u32>>,
    image_nums: Vec<usize>,
    video_nums: Vec<usize>,
}

impl VisionModel for Qwen2_5VLModel {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let Qwen2_5VLVisionSpecificArgs {
            input_ids_full,
            image_grid_thw,
            video_grid_thw,
            seqlens,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Qwen2_5VLVisionSpecificArgs`");
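        // A single `pixel_values` tensor carries either image or video
        // patches; whichever grid is present disambiguates.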
        let (pixel_values, pixel_values_video) = match (&image_grid_thw, &video_grid_thw) {
            (Some(_), None) => (pixel_values, None),
            (None, Some(_)) => (None, pixel_values),
            (None, None) => (None, None),
            (Some(_), Some(_)) => {
                candle_core::bail!("Images and videos cannot be provided together.")
            }
        };
        self.forward(
            input_ids,
            &input_ids_full,
            pixel_values,
            pixel_values_video,
            image_grid_thw,
            video_grid_thw,
            seqlens,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
            seqlen_offsets,
            context_lens,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.text.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.text.cache
    }
    fn device(&self) -> &Device {
        &self.text.device
    }
    fn max_seq_len(&self) -> usize {
        self.text.max_seq_len
    }
    fn has_conv2d(&self) -> bool {
        true
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.text.cfg
    }
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any> {
        assert_eq!(input_ids.dims()[0], 1);
        Box::new(Qwen2_5VLVisionSpecificArgs {
            input_ids_full: input_ids.clone(),
            image_grid_thw: None,
            video_grid_thw: None,
            seqlens: vec![input_ids.dims()[1]],
            continuous_img_pad: vec![],
            continuous_vid_pad: vec![],
            input_ids_searching: vec![vec![]; input_ids.dims()[0]],
            image_nums: vec![0; input_ids.dims()[0]],
            video_nums: vec![0; input_ids.dims()[0]],
        })
    }
}

impl IsqModel for Qwen2_5VLModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.text.get_layers()
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        self.text.residual_tensors()
    }
}

impl AnyMoeBaseModelMixin for Qwen2_5VLModel {}