mistralrs_core/vision_models/qwen2vl/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use text::Qwen2VLTextModel;
use vision::Qwen2VLVisionModel;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::CausalMasker,
    layers_masker::{masked_fill, PastKvLenCache},
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
    },
};

mod config;
mod inputs_processor;
mod text;
mod vision;

pub(crate) use config::Config;
pub(crate) use inputs_processor::Qwen2VLProcessor;

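/// Qwen2-VL vision-language model: a ViT-style vision tower (`Qwen2VLVisionModel`) whose
/// patch embeddings are spliced into the language model's (`Qwen2VLTextModel`) token
/// embeddings at the image/video placeholder positions.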
pub struct Qwen2VLModel {
    text: Qwen2VLTextModel,
    vision: Qwen2VLVisionModel,
    spatial_merge_size: usize,
    image_token_id: u32,
    video_token_id: u32,
}

impl Qwen2VLModel {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if cfg.use_sliding_window {
            // TODO!
            candle_core::bail!("Sliding window is unsupported for now!");
        }
        // Support both HuggingFace naming (visual.*) and MLX naming (vision_tower.*)
        let vision_vb = if vb.contains_tensor("vision_tower.patch_embed.proj.weight") {
            vb.pp("vision_tower")
        } else {
            vb.pp("visual")
        };
        let vision = Qwen2VLVisionModel::new(
            &cfg.vision_config,
            vision_vb.set_device(normal_loading_metadata.real_device.clone()),
            &normal_loading_metadata.mapper.get_comm_for(0)?,
        )?;
        let text = Qwen2VLTextModel::new(
            cfg,
            vb.clone(),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        Ok(Self {
            text,
            vision,
            spatial_merge_size: cfg.vision_config.spatial_merge_size,
            image_token_id: cfg.image_token_id,
            video_token_id: cfg.video_token_id,
        })
    }

    #[allow(clippy::too_many_arguments)]
    /// Compute the multimodal (M-RoPE) position ids. Returns `(position_ids, mrope_position_deltas)`,
    /// where `position_ids` has shape `(3, batch, seq_len)` holding the temporal/height/width indices
    /// and `mrope_position_deltas` has shape `(batch, 1)`, the offset to apply to positions generated
    /// after the prompt (i.e. during decoding).
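    ///
    /// Illustrative sketch (mirroring the reference HuggingFace `get_rope_index`): for 3 text tokens,
    /// one image occupying a 1x2x2 (t, h, w) grid of merged patches (4 tokens), then 2 more text tokens:
    /// temporal ids: [0, 1, 2, 3, 3, 3, 3, 5, 6]
    /// height ids:   [0, 1, 2, 3, 3, 4, 4, 5, 6]
    /// width ids:    [0, 1, 2, 3, 4, 3, 4, 5, 6]
    /// Text tokens share the same id on all three axes; vision tokens enumerate their (t, h, w) grid
    /// offset by the preceding text length, and text after the image resumes at `max(vision ids) + 1`.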
    fn get_rope_index(
        &self,
        input_ids: &Tensor,
        image_grid_thw: Option<&Tensor>,
        video_grid_thw: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        attention_mask_indices: Option<&Tensor>,
        input_ids_searching: Vec<Vec<u32>>,
        image_nums: Vec<usize>,
        video_nums: Vec<usize>,
    ) -> Result<(Tensor, Tensor)> {
        if image_grid_thw.is_some() || video_grid_thw.is_some() {
            let total_input_ids = input_ids.clone();
            let mut position_ids = Tensor::zeros(
                (3, input_ids.dim(0)?, input_ids.dim(1)?),
                DType::I64,
                input_ids.device(),
            )?;
            let mut mrope_position_deltas = Vec::new();

            let mut image_index = 0;
            let mut video_index = 0;
            for (i, mut input_ids) in total_input_ids
                .chunk(input_ids.dim(0)?, 0)?
                .into_iter()
                .enumerate()
            {
                if let Some(attention_mask_indices) = attention_mask_indices {
                    input_ids = input_ids
                        .i(i)?
                        .to_dtype(DType::F32)?
                        .index_select(&attention_mask_indices.squeeze(0)?, 0)?
                        .to_dtype(input_ids.dtype())?;
                }
                let image_nums = image_nums[i];
                let vision_nums = video_nums[i];

                let mut llm_pos_ids: Vec<Tensor> = Vec::new();
                let mut max_last_llm_pos_ids = None;
                let mut max_llm_pos_ids = 0;
                let mut st = 0;
                let (mut remain_images, mut remain_videos) = (image_nums, vision_nums);
                for _ in 0..(image_nums + vision_nums) {
                    let ed_image = if input_ids_searching[i].contains(&self.image_token_id)
                        && remain_images > 0
                    {
                        input_ids_searching[i][st..]
                            .iter()
                            .position(|&t| t == self.image_token_id)
                            .unwrap()
                            + st
                    } else {
                        input_ids.dim(0)? + 1
                    };
                    let ed_video = if input_ids_searching[i].contains(&self.video_token_id)
                        && remain_videos > 0
                    {
                        input_ids_searching[i][st..]
                            .iter()
                            .position(|&t| t == self.video_token_id)
                            .unwrap()
                            + st
                    } else {
                        input_ids.dim(0)? + 1
                    };
                    let (ed, llm_grid_t, h, w) = if ed_image < ed_video {
                        let t = image_grid_thw.as_ref().unwrap().i((image_index, 0))?;
                        let h = image_grid_thw.as_ref().unwrap().i((image_index, 1))?;
                        let w = image_grid_thw.as_ref().unwrap().i((image_index, 2))?;
                        image_index += 1;
                        remain_images -= 1;
                        (
                            ed_image,
                            t.to_scalar::<u32>()?,
                            h.to_scalar::<u32>()?,
                            w.to_scalar::<u32>()?,
                        )
                    } else {
                        let t = video_grid_thw.as_ref().unwrap().i((video_index, 0))?;
                        let h = video_grid_thw.as_ref().unwrap().i((video_index, 1))?;
                        let w = video_grid_thw.as_ref().unwrap().i((video_index, 2))?;
                        video_index += 1;
                        remain_videos -= 1;
                        (
                            ed_video,
                            t.to_scalar::<u32>()?,
                            h.to_scalar::<u32>()?,
                            w.to_scalar::<u32>()?,
                        )
                    };
                    let llm_grid_h = h / self.spatial_merge_size as u32;
                    let llm_grid_w = w / self.spatial_merge_size as u32;
                    let text_len = ed - st;

                    let st_idx = max_last_llm_pos_ids.unwrap_or(0);
                    // max_last_llm_pos_ids = Some(text_len as i64 + st_idx);
                    max_llm_pos_ids = max_llm_pos_ids.max(text_len as i64 + st_idx);
                    llm_pos_ids.push(
                        Tensor::arange(st_idx, text_len as i64 + st_idx, input_ids.device())?
                            .unsqueeze(0)?
                            .repeat((3, 1))?,
                    );

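                    // Vision tokens get a 3D position grid: the temporal index varies slowest,
                    // then height, then width, each flattened to one id per merged patch and
                    // offset below by the position reached at the end of the preceding text span.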
                    let t_idx = Tensor::arange(0, llm_grid_t as i64, input_ids.device())?
                        .reshape(((), 1))?
                        .repeat((1, llm_grid_h as usize * llm_grid_w as usize))?
                        .flatten_all()?;
                    let h_idx = Tensor::arange(0, llm_grid_h as i64, input_ids.device())?
                        .reshape((1, (), 1))?
                        .repeat((llm_grid_t as usize, 1, llm_grid_w as usize))?
                        .flatten_all()?;
                    let w_idx = Tensor::arange(0, llm_grid_w as i64, input_ids.device())?
                        .reshape((1, 1, ()))?
                        .repeat((llm_grid_t as usize, llm_grid_h as usize, 1))?
                        .flatten_all()?;
                    max_last_llm_pos_ids = Some(
                        *[llm_grid_t, llm_grid_h, llm_grid_w].iter().max().unwrap() as i64
                            + (text_len as i64 + st_idx),
                    );
                    max_llm_pos_ids = max_llm_pos_ids.max(max_last_llm_pos_ids.unwrap());
                    llm_pos_ids.push(
                        (Tensor::stack(&[t_idx, h_idx, w_idx], 0)?.to_dtype(DType::F32)?
                            + (text_len + st_idx as usize) as f64)?
                            .to_dtype(DType::I64)?,
                    );
                    st = ed + (llm_grid_t * llm_grid_h * llm_grid_w) as usize;
                }

                if st < input_ids.dim(0)? {
                    let st_idx = max_last_llm_pos_ids.unwrap_or(0);
                    let text_len = (input_ids.dim(0)? - st) as u32;
                    // max_last_llm_pos_ids = Some(text_len as i64 + st_idx);
                    max_llm_pos_ids = max_llm_pos_ids.max(text_len as i64 + st_idx);
                    llm_pos_ids.push(
                        Tensor::arange(st_idx, text_len as i64 + st_idx, input_ids.device())?
                            .reshape((1, ()))?
                            .repeat((3, 1))?,
                    );
                }

                let llm_positions = Tensor::cat(&llm_pos_ids, 1)?.reshape((3, ()))?;
                let positions_mask = attention_mask
                    .as_ref()
                    .unwrap()
                    .i(i)?
                    .eq(1f64)?
                    .unsqueeze(0)?
                    .repeat((3, 1))?;

                position_ids = position_ids.slice_assign(
                    &[0..position_ids.dim(0)?, i..i + 1, 0..position_ids.dim(2)?],
                    &positions_mask
                        .where_cond(&llm_positions, &position_ids.i((.., i, ..))?)?
                        .unsqueeze(1)?,
                )?;
                mrope_position_deltas
                    .push(max_llm_pos_ids + 1 - total_input_ids.i(i)?.dim(0)? as i64);
            }
            let mrope_position_deltas_len = mrope_position_deltas.len();
            let mrope_position_deltas = Tensor::from_vec(
                mrope_position_deltas,
                (mrope_position_deltas_len,),
                input_ids.device(),
            )?
            .unsqueeze(1)?;
            Ok((position_ids, mrope_position_deltas))
        } else if let Some(attention_mask) = attention_mask {
            let position_ids = (attention_mask.to_dtype(DType::F32)?.cumsum(D::Minus1)? - 1f64)?;
            let position_ids = masked_fill(&position_ids, &attention_mask.eq(0f64)?, 1i64)?;
            let position_ids = position_ids.unsqueeze(0)?.repeat((3, 1, 1))?;

            let max_position_ids = position_ids.max(0)?.max_keepdim(D::Minus1)?;
            let mrope_position_deltas =
                ((max_position_ids + 1.)? - attention_mask.dim(D::Minus1)? as f64)?;

            Ok((
                position_ids.to_dtype(DType::I64)?,
                mrope_position_deltas.to_dtype(DType::I64)?,
            ))
        } else {
            let position_ids = Tensor::arange(0i64, input_ids.dim(1)? as i64, input_ids.device())?
                .reshape((1, 1, ()))?
                .repeat((3, input_ids.dim(0)?, 1))?;
            let mrope_position_deltas =
                Tensor::zeros((input_ids.dim(0)?, 1), DType::I64, input_ids.device())?;

            Ok((position_ids, mrope_position_deltas))
        }
    }

    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        pixel_values: Option<Tensor>,
        pixel_values_videos: Option<Tensor>,
        image_grid_thw: Option<Tensor>,
        video_grid_thw: Option<Tensor>,
        seqlens: Vec<usize>,
        continuous_img_pad: Vec<Vec<(usize, usize)>>,
        continuous_vid_pad: Vec<Vec<(usize, usize)>>,
        input_ids_searching: Vec<Vec<u32>>,
        image_nums: Vec<usize>,
        video_nums: Vec<usize>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &seqlen_offsets as &dyn PastKvLenCache,
            self.text.cfg.sliding_window,
            self.text.dtype,
            self.text.cfg.num_attn_heads,
        )?;

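        // When pixel values are present, run the vision tower and overwrite the embeddings of the
        // image/video placeholder tokens with the projected patch embeddings; otherwise this is a
        // plain text embedding lookup.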
        let input_embeds = if pixel_values.is_some() || pixel_values_videos.is_some() {
            let mut xs = self.text.embed_tokens(input_ids)?;

            if let Some(pixel_values) = pixel_values {
                let image_embeds = self
                    .vision
                    .forward(
                        &pixel_values,
                        image_grid_thw
                            .as_ref()
                            .context("pixel_values require image_grid_thw")?,
                    )?
                    .to_dtype(self.text.dtype)?;

                for (batch, batch_ids) in continuous_img_pad.into_iter().enumerate() {
                    let mut last_end = 0;
                    for (start, end) in batch_ids {
                        xs = xs.slice_assign(
                            &[batch..batch + 1, start..end, 0..xs.dim(2)?],
                            &image_embeds
                                .i((last_end..last_end + (end - start), ..))?
                                .unsqueeze(0)?,
                        )?;
                        // Advance the cumulative offset into `image_embeds` consumed so far.
                        last_end += end - start;
                    }
                }
            }

            if let Some(pixel_values_videos) = pixel_values_videos {
                // Convert to the text model dtype, mirroring the image embedding path above.
                let video_embeds = self
                    .vision
                    .forward(
                        &pixel_values_videos,
                        video_grid_thw
                            .as_ref()
                            .context("pixel_values_videos require video_grid_thw")?,
                    )?
                    .to_dtype(self.text.dtype)?;

                for (batch, batch_ids) in continuous_vid_pad.into_iter().enumerate() {
                    let mut last_end = 0;
                    for (start, end) in batch_ids {
                        xs = xs.slice_assign(
                            &[batch..batch + 1, start..end, 0..xs.dim(2)?],
                            &video_embeds
                                .i((last_end..last_end + (end - start), ..))?
                                .unsqueeze(0)?,
                        )?;
                        // Advance the cumulative offset into `video_embeds` consumed so far.
                        last_end += end - start;
                    }
                }
            }

            xs
        } else {
            self.text.embed_tokens(input_ids)?
        };

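        // Build a padding mask (1 for real tokens, 0 up to the longest sequence) and per-sequence
        // absolute index tensors; these are only used to compute the M-RoPE position ids below.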
        let mut ropeidx_attn_mask_bs = Vec::new();
        let max_seqlens = *seqlens.iter().max().unwrap();
        for len in &seqlens {
            ropeidx_attn_mask_bs.push(Tensor::new(
                [vec![1f32; *len], vec![0f32; max_seqlens - len]].concat(),
                input_ids.device(),
            )?);
        }
        let ropeidx_attn_mask = Tensor::stack(&ropeidx_attn_mask_bs, 0)?;
        let mut ropeidx_attn_mask_indices_bs = Vec::new();
        for (len, offset) in seqlens.iter().zip(seqlen_offsets) {
            ropeidx_attn_mask_indices_bs.push(Tensor::from_vec(
                (*offset as i64..(*len as i64 + *offset as i64)).collect(),
                (*len,),
                input_ids.device(),
            )?);
        }
        let ropeidx_attn_mask_indices = Tensor::stack(&ropeidx_attn_mask_indices_bs, 0)?;

        let ropeidx_input_ids = if attention_mask.is_some() {
            input_ids
        } else {
            input_ids_full
        };
        let (position_ids, mrope_position_deltas) = self.get_rope_index(
            ropeidx_input_ids,
            image_grid_thw.as_ref(),
            video_grid_thw.as_ref(),
            Some(&ropeidx_attn_mask),
            Some(&ropeidx_attn_mask_indices),
            input_ids_searching,
            image_nums,
            video_nums,
        )?;

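        // Prefill uses the multimodal position ids directly; during decoding (no attention mask),
        // positions are the current sequence offsets shifted by the per-sequence mrope deltas.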
        let position_ids = if attention_mask.is_some() {
            position_ids
        } else {
            let mut position_ids = Tensor::new(
                seqlen_offsets.iter().map(|x| *x as i64).collect::<Vec<_>>(),
                input_ids.device(),
            )?
            .reshape((1, (), 1))?
            .repeat((3, 1, 1))?;

            position_ids = position_ids.broadcast_add(&mrope_position_deltas.unsqueeze(0)?)?;

            position_ids
        };

        let out = self.text.forward_embeds(
            input_embeds,
            attention_mask.as_ref(),
            &position_ids,
            context_lens,
            flash_params,
        )?;
        Ok(out)
    }
}

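/// Extra per-batch inputs threaded from the inputs processor into [`Qwen2VLModel::forward`]
/// via the type-erased `model_specific_args`.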
pub(crate) struct Qwen2VLVisionSpecificArgs {
    input_ids_full: Tensor,
    image_grid_thw: Option<Tensor>, // Some when pixel values are provided
    video_grid_thw: Option<Tensor>, // Some when pixel values are provided
    seqlens: Vec<usize>,
    continuous_img_pad: Vec<Vec<(usize, usize)>>,
    continuous_vid_pad: Vec<Vec<(usize, usize)>>,
    input_ids_searching: Vec<Vec<u32>>,
    image_nums: Vec<usize>,
    video_nums: Vec<usize>,
}

impl VisionModel for Qwen2VLModel {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let Qwen2VLVisionSpecificArgs {
            input_ids_full,
            image_grid_thw,
            video_grid_thw,
            seqlens,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Qwen2VLVisionSpecificArgs`");
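        // The pipeline passes either image or video pixel values through the same `pixel_values`
        // argument; which grid is present determines how they are routed.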
        let (pixel_values, pixel_values_video) = match (&image_grid_thw, &video_grid_thw) {
            (Some(_), None) => (pixel_values, None),
            (None, Some(_)) => (None, pixel_values),
            (None, None) => (None, None),
            (Some(_), Some(_)) => {
                candle_core::bail!("Images and videos cannot be provided together.")
            }
        };
        self.forward(
            input_ids,
            &input_ids_full,
            pixel_values,
            pixel_values_video,
            image_grid_thw,
            video_grid_thw,
            seqlens,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
            seqlen_offsets,
            context_lens,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.text.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.text.cache
    }
    fn device(&self) -> &Device {
        &self.text.device
    }
    fn max_seq_len(&self) -> usize {
        self.text.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.text.cfg
    }
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any> {
        assert_eq!(input_ids.dims()[0], 1);
        Box::new(Qwen2VLVisionSpecificArgs {
            input_ids_full: input_ids.clone(),
            image_grid_thw: None,
            video_grid_thw: None,
            seqlens: vec![input_ids.dims()[1]],
            continuous_img_pad: vec![],
            continuous_vid_pad: vec![],
            input_ids_searching: vec![vec![]; input_ids.dims()[0]],
            image_nums: vec![0; input_ids.dims()[0]],
            video_nums: vec![0; input_ids.dims()[0]],
        })
    }
}

impl IsqModel for Qwen2VLModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.text.get_layers()
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        self.text.residual_tensors()
    }
}

impl AnyMoeBaseModelMixin for Qwen2VLModel {}