mistralrs_core/vision_models/qwen2vl/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use text::Qwen2VLTextModel;
use vision::Qwen2VLVisionModel;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::CausalMasker,
    layers_masker::{masked_fill, PastKvLenCache},
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
    },
};

mod config;
mod inputs_processor;
mod text;
mod vision;

pub(crate) use config::Config;
pub(crate) use inputs_processor::Qwen2VLProcessor;

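/// Qwen2-VL: a vision tower (`Qwen2VLVisionModel`) paired with a text decoder
/// (`Qwen2VLTextModel`). The special image/video token ids mark the spans in the
/// prompt that are replaced by visual embeddings, and `spatial_merge_size` is the
/// per-axis factor by which the raw patch grid is reduced to LLM tokens.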
pub struct Qwen2VLModel {
    text: Qwen2VLTextModel,
    vision: Qwen2VLVisionModel,
    spatial_merge_size: usize,
    image_token_id: u32,
    video_token_id: u32,
}

impl Qwen2VLModel {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if cfg.use_sliding_window {
            // TODO!
            candle_core::bail!("Sliding window is unsupported for now!");
        }
        let vision = Qwen2VLVisionModel::new(
            &cfg.vision_config,
            vb.pp("visual")
                .set_device(normal_loading_metadata.real_device.clone()),
            &normal_loading_metadata.mapper.get_comm_for(0)?,
        )?;
        let text = Qwen2VLTextModel::new(
            cfg,
            vb.clone(),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        Ok(Self {
            text,
            vision,
            spatial_merge_size: cfg.vision_config.spatial_merge_size,
            image_token_id: cfg.image_token_id,
            video_token_id: cfg.video_token_id,
        })
    }

    #[allow(clippy::too_many_arguments)]
    /// (position_ids, mrope_position_deltas)
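    ///
    /// Computes the multimodal rotary (M-RoPE) position indices. The returned
    /// `position_ids` have shape `(3, bs, seq_len)`: text tokens advance the temporal,
    /// height, and width axes in lockstep, while image/video tokens take their
    /// `(t, h, w)` grid coordinates (after spatial merging), offset by the position of
    /// the preceding text. `mrope_position_deltas` records, per sequence, the offset to
    /// add to plain sequence positions during decoding.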
    fn get_rope_index(
        &self,
        input_ids: &Tensor,
        image_grid_thw: Option<&Tensor>,
        video_grid_thw: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        attention_mask_indices: Option<&Tensor>,
        input_ids_searching: Vec<Vec<u32>>,
        image_nums: Vec<usize>,
        video_nums: Vec<usize>,
    ) -> Result<(Tensor, Tensor)> {
        if image_grid_thw.is_some() || video_grid_thw.is_some() {
            let total_input_ids = input_ids.clone();
            let mut position_ids = Tensor::zeros(
                (3, input_ids.dim(0)?, input_ids.dim(1)?),
                DType::I64,
                input_ids.device(),
            )?;
            let mut mrope_position_deltas = Vec::new();

            let mut image_index = 0;
            let mut video_index = 0;
            for (i, mut input_ids) in total_input_ids
                .chunk(input_ids.dim(0)?, 0)?
                .into_iter()
                .enumerate()
            {
                if let Some(attention_mask_indices) = attention_mask_indices {
                    input_ids = input_ids
                        .i(i)?
                        .to_dtype(DType::F32)?
                        .index_select(&attention_mask_indices.squeeze(0)?, 0)?
                        .to_dtype(input_ids.dtype())?;
                }
                let image_nums = image_nums[i];
                let video_nums = video_nums[i];

                let mut llm_pos_ids: Vec<Tensor> = Vec::new();
                let mut max_last_llm_pos_ids = None;
                let mut max_llm_pos_ids = 0;
                let mut st = 0;
                let (mut remain_images, mut remain_videos) = (image_nums, video_nums);
                for _ in 0..(image_nums + video_nums) {
                    let ed_image = if input_ids_searching[i].contains(&self.image_token_id)
                        && remain_images > 0
                    {
                        input_ids_searching[i][st..]
                            .iter()
                            .position(|&t| t == self.image_token_id)
                            .unwrap()
                            + st
                    } else {
                        input_ids.dim(0)? + 1
                    };
                    let ed_video = if input_ids_searching[i].contains(&self.video_token_id)
                        && remain_videos > 0
                    {
                        input_ids_searching[i][st..]
                            .iter()
                            .position(|&t| t == self.video_token_id)
                            .unwrap()
                            + st
                    } else {
                        input_ids.dim(0)? + 1
                    };
                    let (ed, llm_grid_t, h, w) = if ed_image < ed_video {
                        let t = image_grid_thw.as_ref().unwrap().i((image_index, 0))?;
                        let h = image_grid_thw.as_ref().unwrap().i((image_index, 1))?;
                        let w = image_grid_thw.as_ref().unwrap().i((image_index, 2))?;
                        image_index += 1;
                        remain_images -= 1;
                        (
                            ed_image,
                            t.to_scalar::<u32>()?,
                            h.to_scalar::<u32>()?,
                            w.to_scalar::<u32>()?,
                        )
                    } else {
                        let t = video_grid_thw.as_ref().unwrap().i((video_index, 0))?;
                        let h = video_grid_thw.as_ref().unwrap().i((video_index, 1))?;
                        let w = video_grid_thw.as_ref().unwrap().i((video_index, 2))?;
                        video_index += 1;
                        remain_videos -= 1;
                        (
                            ed_video,
                            t.to_scalar::<u32>()?,
                            h.to_scalar::<u32>()?,
                            w.to_scalar::<u32>()?,
                        )
                    };
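                    // Each spatial axis of the patch grid is divided by `spatial_merge_size`,
                    // since that many patches along each axis are merged into a single LLM
                    // token (e.g. h = 28 patches with a merge size of 2 gives 14 height
                    // positions).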
                    let llm_grid_h = h / self.spatial_merge_size as u32;
                    let llm_grid_w = w / self.spatial_merge_size as u32;
                    let text_len = ed - st;

                    let st_idx = max_last_llm_pos_ids.unwrap_or(0);
                    // max_last_llm_pos_ids = Some(text_len as i64 + st_idx);
                    max_llm_pos_ids = max_llm_pos_ids.max(text_len as i64 + st_idx);
                    llm_pos_ids.push(
                        Tensor::arange(st_idx, text_len as i64 + st_idx, input_ids.device())?
                            .unsqueeze(0)?
                            .repeat((3, 1))?,
                    );

                    let t_idx = Tensor::arange(0, llm_grid_t as i64, input_ids.device())?
                        .reshape(((), 1))?
                        .repeat((1, llm_grid_h as usize * llm_grid_w as usize))?
                        .flatten_all()?;
                    let h_idx = Tensor::arange(0, llm_grid_h as i64, input_ids.device())?
                        .reshape((1, (), 1))?
                        .repeat((llm_grid_t as usize, 1, llm_grid_w as usize))?
                        .flatten_all()?;
                    let w_idx = Tensor::arange(0, llm_grid_w as i64, input_ids.device())?
                        .reshape((1, 1, ()))?
                        .repeat((llm_grid_t as usize, llm_grid_h as usize, 1))?
                        .flatten_all()?;
                    max_last_llm_pos_ids = Some(
                        *[llm_grid_t, llm_grid_h, llm_grid_w].iter().max().unwrap() as i64
                            + (text_len as i64 + st_idx),
                    );
                    max_llm_pos_ids = max_llm_pos_ids.max(max_last_llm_pos_ids.unwrap());
                    llm_pos_ids.push(
                        (Tensor::stack(&[t_idx, h_idx, w_idx], 0)?.to_dtype(DType::F32)?
                            + (text_len + st_idx as usize) as f64)?
                            .to_dtype(DType::I64)?,
                    );
                    st = ed + (llm_grid_t * llm_grid_h * llm_grid_w) as usize;
                }

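                // Any text tokens after the last image/video segment still need positions:
                // continue all three axes from the largest position assigned so far.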
                if st < input_ids.dim(0)? {
                    let st_idx = max_last_llm_pos_ids.unwrap_or(0);
                    let text_len = (input_ids.dim(0)? - st) as u32;
                    // max_last_llm_pos_ids = Some(text_len as i64 + st_idx);
                    max_llm_pos_ids = max_llm_pos_ids.max(text_len as i64 + st_idx);
                    llm_pos_ids.push(
                        Tensor::arange(st_idx, text_len as i64 + st_idx, input_ids.device())?
                            .reshape((1, ()))?
                            .repeat((3, 1))?,
                    );
                }

                let llm_positions = Tensor::cat(&llm_pos_ids, 1)?.reshape((3, ()))?;
                let positions_mask = attention_mask
                    .as_ref()
                    .unwrap()
                    .i(i)?
                    .eq(1f64)?
                    .unsqueeze(0)?
                    .repeat((3, 1))?;

                position_ids = position_ids.slice_assign(
                    &[&.., &i, &..],
                    &positions_mask
                        .where_cond(&llm_positions, &position_ids.i((.., i, ..))?)?
                        .unsqueeze(1)?,
                )?;
                mrope_position_deltas
                    .push(max_llm_pos_ids + 1 - total_input_ids.i(i)?.dim(0)? as i64);
            }
            let mrope_position_deltas_len = mrope_position_deltas.len();
            let mrope_position_deltas = Tensor::from_vec(
                mrope_position_deltas,
                (mrope_position_deltas_len,),
                input_ids.device(),
            )?
            .unsqueeze(1)?;
            Ok((position_ids, mrope_position_deltas))
        } else if let Some(attention_mask) = attention_mask {
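            // Pure-text case with an attention mask: positions are the cumulative count of
            // attended tokens (identical across the three rope axes), and the delta is
            // `max_position + 1 - seq_len`.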
            let position_ids = (attention_mask.to_dtype(DType::F32)?.cumsum(D::Minus1)? - 1f64)?;
            let position_ids = masked_fill(&position_ids, &attention_mask.eq(0f64)?, 1i64)?;
            let position_ids = position_ids.unsqueeze(0)?.repeat((3, 1, 1))?;

            let max_position_ids = position_ids.max(0)?.max_keepdim(D::Minus1)?;
            let mrope_position_deltas =
                ((max_position_ids + 1.)? - attention_mask.dim(D::Minus1)? as f64)?;

            Ok((
                position_ids.to_dtype(DType::I64)?,
                mrope_position_deltas.to_dtype(DType::I64)?,
            ))
        } else {
            let position_ids = Tensor::arange(0i64, input_ids.dim(1)? as i64, input_ids.device())?
                .reshape((1, 1, ()))?
                .repeat((3, input_ids.dim(0)?, 1))?;
            let mrope_position_deltas =
                Tensor::zeros((input_ids.dim(0)?, 1), DType::I64, input_ids.device())?;

            Ok((position_ids, mrope_position_deltas))
        }
    }

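    /// Run the full multimodal forward pass: embed `input_ids`, splice any image/video
    /// embeddings from the vision tower into their placeholder spans, compute M-RoPE
    /// position ids, and decode with the text model.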
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        pixel_values: Option<Tensor>,
        pixel_values_videos: Option<Tensor>,
        image_grid_thw: Option<Tensor>,
        video_grid_thw: Option<Tensor>,
        seqlens: Vec<usize>,
        continuous_img_pad: Vec<Vec<(usize, usize)>>,
        continuous_vid_pad: Vec<Vec<(usize, usize)>>,
        input_ids_searching: Vec<Vec<u32>>,
        image_nums: Vec<usize>,
        video_nums: Vec<usize>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &seqlen_offsets as &dyn PastKvLenCache,
            self.text.cfg.sliding_window,
            self.text.dtype,
            self.text.cfg.num_attn_heads,
        )?;

        let input_embeds = if pixel_values.is_some() || pixel_values_videos.is_some() {
            let mut xs = self.text.embed_tokens(input_ids)?;

            if let Some(pixel_values) = pixel_values {
                let image_embeds = self
                    .vision
                    .forward(
                        &pixel_values,
                        image_grid_thw
                            .as_ref()
                            .context("pixel_values require image_grid_thw")?,
                    )?
                    .to_dtype(self.text.dtype)?;

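                // Overwrite the image placeholder spans in the token embeddings with the
                // corresponding rows of the vision-tower output.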
                for (batch, batch_ids) in continuous_img_pad.into_iter().enumerate() {
                    let mut last_end = 0;
                    for (start, end) in batch_ids {
                        xs = xs.slice_assign(
                            &[&batch, &(start..end), &..],
                            &image_embeds
                                .i((last_end..last_end + (end - start), ..))?
                                .unsqueeze(0)?,
                        )?;
                        last_end = end - start;
                    }
                }
            }

            if let Some(pixel_values_videos) = pixel_values_videos {
                let video_embeds = self.vision.forward(
                    &pixel_values_videos,
                    video_grid_thw
                        .as_ref()
                        .context("pixel_values_videos require video_grid_thw")?,
                )?;

                for (batch, batch_ids) in continuous_vid_pad.into_iter().enumerate() {
                    let mut last_end = 0;
                    for (start, end) in batch_ids {
                        xs = xs.slice_assign(
                            &[&batch, &(start..end), &..],
                            &video_embeds
                                .i((last_end..last_end + (end - start), ..))?
                                .unsqueeze(0)?,
                        )?;
                        last_end = end - start;
                    }
                }
            }

            xs
        } else {
            self.text.embed_tokens(input_ids)?
        };

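        // get_rope_index expects a per-sequence attention mask and the indices of the
        // attended positions; build both from the raw sequence lengths (1s up to each
        // length, padded with 0s to the longest sequence in the batch).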
        let mut ropeidx_attn_mask_bs = Vec::new();
        let max_seqlens = *seqlens.iter().max().unwrap();
        for len in &seqlens {
            ropeidx_attn_mask_bs.push(Tensor::new(
                [vec![1f32; *len], vec![0f32; max_seqlens - len]].concat(),
                input_ids.device(),
            )?);
        }
        let ropeidx_attn_mask = Tensor::stack(&ropeidx_attn_mask_bs, 0)?;
        let mut ropeidx_attn_mask_indices_bs = Vec::new();
        for len in seqlens {
            ropeidx_attn_mask_indices_bs.push(Tensor::from_vec(
                (0..len as i64).collect(),
                (len,),
                input_ids.device(),
            )?);
        }
        let ropeidx_attn_mask_indices = Tensor::stack(&ropeidx_attn_mask_indices_bs, 0)?;

        let ropeidx_input_ids = if attention_mask.is_some() {
            input_ids
        } else {
            input_ids_full
        };
        let (position_ids, mrope_position_deltas) = self.get_rope_index(
            ropeidx_input_ids,
            image_grid_thw.as_ref(),
            video_grid_thw.as_ref(),
            Some(&ropeidx_attn_mask),
            Some(&ropeidx_attn_mask_indices),
            input_ids_searching,
            image_nums,
            video_nums,
        )?;

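        // Prefill (a causal mask is present) uses the M-RoPE positions directly; during
        // decoding we instead take the current offset of each sequence and shift it by
        // that sequence's mrope delta.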
        let position_ids = if attention_mask.is_some() {
            position_ids
        } else {
            let mut position_ids = Tensor::new(
                seqlen_offsets.iter().map(|x| *x as i64).collect::<Vec<_>>(),
                input_ids.device(),
            )?
            .reshape((1, (), 1))?
            .repeat((3, 1, 1))?;

            position_ids = position_ids.broadcast_add(&mrope_position_deltas.unsqueeze(0)?)?;

            position_ids
        };

        let out = self.text.forward_embeds(
            input_embeds,
            attention_mask.as_ref(),
            &position_ids,
            context_lens,
            flash_params,
        )?;
        Ok(out)
    }
}

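/// Extra per-request inputs threaded through the generic `VisionModel::forward` call
/// and unpacked by `Qwen2VLModel::forward`.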
pub(crate) struct Qwen2VLVisionSpecificArgs {
    input_ids_full: Tensor,
    image_grid_thw: Option<Tensor>, // Some when pixel values are provided
    video_grid_thw: Option<Tensor>, // Some when pixel values are provided
    seqlens: Vec<usize>,
    continuous_img_pad: Vec<Vec<(usize, usize)>>,
    continuous_vid_pad: Vec<Vec<(usize, usize)>>,
    input_ids_searching: Vec<Vec<u32>>,
    image_nums: Vec<usize>,
    video_nums: Vec<usize>,
}

impl VisionModel for Qwen2VLModel {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let Qwen2VLVisionSpecificArgs {
            input_ids_full,
            image_grid_thw,
            video_grid_thw,
            seqlens,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Qwen2VLVisionSpecificArgs`");
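        // The trait only passes a single `pixel_values` tensor; whether it holds image
        // or video pixels is disambiguated by which grid (image vs. video) was supplied.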
        let (pixel_values, pixel_values_video) = match (&image_grid_thw, &video_grid_thw) {
            (Some(_), None) => (pixel_values, None),
            (None, Some(_)) => (None, pixel_values),
            (None, None) => (None, None),
            (Some(_), Some(_)) => {
                candle_core::bail!("Images and videos cannot be provided together.")
            }
        };
        self.forward(
            input_ids,
            &input_ids_full,
            pixel_values,
            pixel_values_video,
            image_grid_thw,
            video_grid_thw,
            seqlens,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
            seqlen_offsets,
            context_lens,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.text.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.text.cache
    }
    fn device(&self) -> &Device {
        &self.text.device
    }
    fn max_seq_len(&self) -> usize {
        self.text.max_seq_len
    }
    fn has_conv2d(&self) -> bool {
        true
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.text.cfg
    }
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any> {
        assert_eq!(input_ids.dims()[0], 1);
        Box::new(Qwen2VLVisionSpecificArgs {
            input_ids_full: input_ids.clone(),
            image_grid_thw: None,
            video_grid_thw: None,
            seqlens: vec![input_ids.dims()[1]],
            continuous_img_pad: vec![],
            continuous_vid_pad: vec![],
            input_ids_searching: vec![vec![]; input_ids.dims()[0]],
            image_nums: vec![0; input_ids.dims()[0]],
            video_nums: vec![0; input_ids.dims()[0]],
        })
    }
}

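// In-situ quantization (ISQ) and residual-tensor access delegate to the text model;
// the vision tower is not exposed through this impl.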
impl IsqModel for Qwen2VLModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.text.get_layers()
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        self.text.residual_tensors()
    }
}

impl AnyMoeBaseModelMixin for Qwen2VLModel {}