mistralrs_core/vision_models/minicpmo/
mod.rs

1use std::{any::Any, sync::Arc};
2
3use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
4pub use config::MiniCpmOConfig;
5pub use inputs_processor::MiniCpmOProcessor;
6use mistralrs_quant::{CollectedImatrixData, QuantMethod, ShardedVarBuilder};
7use resampler::Resampler;
8
9use crate::{
10    amoe::AnyMoeBaseModelMixin,
11    device_map::DeviceMapper,
12    models::qwen2,
13    paged_attention::{AttentionImplementation, ModelConfigMetadata},
14    pipeline::{
15        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
16        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
17    },
18    utils::unvarbuilder::UnVarBuilder,
19};
20
21use self::siglip::SiglipVisionTransformer;
22
23use super::siglip;
24
25mod config;
26mod inputs_processor;
27mod resampler;
28
29pub struct MiniCpmOModel {
30    cfg: MiniCpmOConfig,
31    llm: qwen2::Model,
32    vpm: SiglipVisionTransformer,
33    resampler: Resampler,
34}
35
36impl MiniCpmOModel {
37    pub fn new(
38        cfg: &MiniCpmOConfig,
39        vb: ShardedVarBuilder,
40        is_gptx: bool,
41        normal_loading_metadata: NormalLoadingMetadata,
42        attention_mechanism: AttentionImplementation,
43    ) -> Result<Self> {
44        let real_device = normal_loading_metadata.real_device.clone();
45        let llm = qwen2::Model::new(
46            &cfg.text_config,
47            vb.pp("llm"),
48            is_gptx,
49            normal_loading_metadata,
50            attention_mechanism,
51        )?;
52        let vpm = SiglipVisionTransformer::new(
53            &cfg.vision_config,
54            vb.pp("vpm").set_device(real_device.clone()),
55        )?;
56        let resampler = Resampler::new(
57            cfg.query_num,
58            cfg.text_config.hidden_size,
59            cfg.text_config.hidden_size / 128,
60            cfg.vision_config.hidden_size,
61            true,
62            None,
63            vb.pp("resampler").set_device(real_device.clone()),
64        )?;
65        Ok(Self {
66            cfg: cfg.clone(),
67            llm,
68            vpm,
69            resampler,
70        })
71    }
72
73    fn get_vllm_embedding(
74        &self,
75        input_ids: &Tensor,
76        device: &Device,
77        pixel_values_all: Option<Vec<Vec<Tensor>>>,
78        tgt_sizes: Option<Vec<Tensor>>,
79        image_bound: Option<Vec<Tensor>>,
80    ) -> Result<Tensor> {
81        let mut vllm_embedding = self.llm.get_input_embeddings(input_ids)?;
82
83        if let Some(pixel_values_all) = pixel_values_all {
84            let tgt_sizes_all = tgt_sizes.as_ref().expect("Need tgt_sizes");
85            let image_bound = image_bound.expect("Need image_bound");
86            let image_bound_vec = image_bound
87                .into_iter()
88                .map(|x| x.to_vec2::<u32>())
89                .collect::<Result<Vec<_>>>()?;
90
91            let mut all_pixel_values = Vec::new();
92            let mut img_cnts = Vec::new();
93            for pixel_values in &pixel_values_all {
94                img_cnts.push(pixel_values.len());
95                let mut imgs = Vec::new();
96                for i in pixel_values {
97                    // Assume channel dimension first
98                    imgs.push(i.flatten_to(1)?.permute((1, 0))?);
99                }
100                all_pixel_values.extend(imgs);
101            }
102
103            let tgt_sizes = Tensor::cat(tgt_sizes_all, 0)?;
104            let tgt_sizes_vec = tgt_sizes.to_vec2::<u32>()?;
105
106            let max_patches = (tgt_sizes.i((.., 0))? * tgt_sizes.i((.., 1))?)?
107                .max(0)?
108                .to_scalar::<u32>()? as usize;
109
110            // Original code does padding of the pixel values here
111            let lens = all_pixel_values
112                .iter()
113                .map(|pixel_values| pixel_values.dim(0))
114                .collect::<Result<Vec<_>>>()?;
115            let max_len = lens.into_iter().max().expect("No pixel values somehow?");
116            all_pixel_values = all_pixel_values
117                .into_iter()
118                .map(|pixel_values| {
119                    pixel_values.pad_with_zeros(0, 0, max_len - pixel_values.dim(0)?)
120                })
121                .collect::<Result<Vec<_>>>()?;
122            let mut all_pixel_values = Tensor::stack(&all_pixel_values, 0)?;
123
124            let (b, l, _) = all_pixel_values.dims3()?;
125            all_pixel_values = all_pixel_values
126                .permute((0, 2, 1))?
127                .reshape((b, 3, (), l))?;
128
129            let mut patch_attn_mask = Tensor::zeros((b, 1, max_patches), DType::U8, device)?;
130            for (i, tgt_sizes_vec_i) in tgt_sizes_vec.iter().enumerate().take(b) {
131                let n = (tgt_sizes_vec_i[0] * tgt_sizes_vec_i[1]) as usize;
132                patch_attn_mask = patch_attn_mask.slice_assign(
133                    &[&i, &0, &(..n)],
134                    &Tensor::ones((1, 1, n), DType::U8, device)?,
135                )?;
136            }
137
138            let vision_batch_size = self.cfg.vision_batch_size;
139            all_pixel_values = all_pixel_values.to_dtype(self.llm.embed_dtype())?;
140
141            let mut vision_embedding = if b > vision_batch_size {
142                let mut hs = Vec::new();
143                for i in (0..b).step_by(vision_batch_size) {
144                    let start_idx = i;
145                    let end_idx = i + vision_batch_size;
146                    let tmp_hs = self.vpm.forward(
147                        &all_pixel_values.i(start_idx..end_idx)?,
148                        Some(&patch_attn_mask.i(start_idx..end_idx)?),
149                        Some(&tgt_sizes.i(start_idx..end_idx)?),
150                    )?;
151                    hs.push(tmp_hs);
152                }
153                Tensor::cat(&hs, 0)?
154            } else {
155                self.vpm
156                    .forward(&all_pixel_values, Some(&patch_attn_mask), Some(&tgt_sizes))?
157            };
158            vision_embedding = self.resampler.forward(&vision_embedding, &tgt_sizes_vec)?;
159
160            let mut start = 0;
161            let mut vision_hidden_states = Vec::new();
162            for pixel_values in &pixel_values_all {
163                let img_cnt = pixel_values.len();
164                if img_cnt > 0 {
165                    vision_hidden_states.push(Some(
166                        vision_embedding
167                            .i(start..start + img_cnt)?
168                            .to_dtype(vllm_embedding.dtype())?,
169                    ));
170                    start += img_cnt;
171                } else {
172                    vision_hidden_states.push(None);
173                }
174            }
175
176            let mut new_vllm_embedding = Vec::new();
177            for i in 0..input_ids.dim(0)? {
178                if let Some(cur_vs_hs) = &vision_hidden_states[i] {
179                    let mut cur_vllm_emb = vllm_embedding.i(i)?;
180                    let cur_image_bound = &image_bound_vec[i];
181                    if !cur_image_bound.is_empty() {
182                        let mut image_indices = Vec::new();
183                        for r in cur_image_bound {
184                            image_indices.push(Tensor::arange(r[0], r[1], device)?);
185                        }
186                        let image_indices = Tensor::stack(&image_indices, 0)?;
187
188                        let indices = image_indices
189                            .reshape(((), 1))?
190                            .repeat((1, cur_vllm_emb.dim(D::Minus1)?))?;
191                        // Zero out the current data
192                        let cur_vllm_emb_neg = cur_vllm_emb.gather(&indices, 0)?.neg()?;
193                        cur_vllm_emb = cur_vllm_emb.scatter_add(&indices, &cur_vllm_emb_neg, 0)?;
194                        // Add the image data
195                        cur_vllm_emb = cur_vllm_emb.scatter_add(
196                            &indices,
197                            &cur_vs_hs.reshape(((), cur_vs_hs.dim(D::Minus1)?))?,
198                            0,
199                        )?;
200                        new_vllm_embedding.push(cur_vllm_emb);
201                    }
202                }
203            }
204            vllm_embedding = Tensor::stack(&new_vllm_embedding, 0)?;
205        }
206
207        Ok(vllm_embedding)
208    }
209
210    #[allow(clippy::too_many_arguments)]
211    pub fn forward(
212        &self,
213        input_ids: &Tensor,
214        pixel_values_all: Option<Vec<Vec<Tensor>>>,
215        tgt_sizes: Option<Vec<Tensor>>,
216        image_bound: Option<Vec<Tensor>>,
217        seqlen_offsets: &[usize],
218        context_lens: Vec<(usize, usize)>,
219        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
220        flash_params: &FlashParams,
221    ) -> Result<Tensor> {
222        let vllm_embedding = self.get_vllm_embedding(
223            input_ids,
224            self.llm.device(),
225            pixel_values_all,
226            tgt_sizes,
227            image_bound,
228        )?;
229
230        self.llm.forward_embed(
231            input_ids,
232            vllm_embedding,
233            seqlen_offsets,
234            context_lens,
235            metadata,
236            flash_params,
237        )
238    }
239}
240
241#[derive(Default)]
242pub(crate) struct MiniCpmOSpecificArgs {
243    pub(crate) pixel_values_all: Option<Vec<Vec<Tensor>>>,
244    pub(crate) tgt_sizes: Option<Vec<Tensor>>,
245    pub(crate) image_bound: Option<Vec<Tensor>>,
246}
247
248impl VisionModel for MiniCpmOModel {
249    fn cache(&self) -> &EitherCache {
250        self.llm.cache()
251    }
252    fn cache_mut(&mut self) -> &mut EitherCache {
253        self.llm.cache_mut()
254    }
255    fn config(&self) -> &ModelConfigMetadata {
256        self.llm.config()
257    }
258    fn device(&self) -> &Device {
259        self.llm.device()
260    }
261    fn has_conv2d(&self) -> bool {
262        true
263    }
264    fn max_seq_len(&self) -> usize {
265        self.llm.max_seq_len()
266    }
267    fn forward(
268        &self,
269        input_ids: &Tensor,
270        _pixel_values: Option<Tensor>,
271        seqlen_offsets: &[usize],
272        context_lens: Vec<(usize, usize)>,
273        _position_ids: Vec<usize>,
274        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
275        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
276        flash_params: &FlashParams,
277    ) -> Result<Tensor> {
278        let MiniCpmOSpecificArgs {
279            pixel_values_all,
280            tgt_sizes,
281            image_bound,
282        } = *model_specific_args
283            .downcast()
284            .expect("Cannot downcast into `MiniCpmOSpecificArgs`");
285        self.forward(
286            input_ids,
287            pixel_values_all,
288            tgt_sizes,
289            image_bound,
290            seqlen_offsets,
291            context_lens,
292            metadata,
293            flash_params,
294        )
295    }
296    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
297        Box::new(MiniCpmOSpecificArgs {
298            pixel_values_all: None,
299            tgt_sizes: None,
300            image_bound: None,
301        })
302    }
303}
304
305impl IsqModel for MiniCpmOModel {
306    fn get_layers(
307        &mut self,
308    ) -> (
309        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
310        &dyn DeviceMapper,
311    ) {
312        self.llm.get_layers()
313    }
314
315    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
316        let uvb = UnVarBuilder::new();
317
318        uvb.pp("llm").extend(self.llm.residual_tensors());
319        uvb.pp("vpm").extend(self.vpm.residual_tensors());
320        uvb.pp("resampler")
321            .extend(self.resampler.residual_tensors());
322
323        uvb.to_safetensors()
324    }
325
326    // NOTE: We ONLY calibrate the text bits of these models, so we should only track/return those parts!!
327
328    /// This is used for imatrix generation internally. Begin stats tracking.
329    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
330        self.llm.begin_track_stats()
331    }
332
333    /// End stats tracking and return the imatrix data
334    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
335        self.llm.extract_imatrix_data()
336    }
337}
338
339impl AnyMoeBaseModelMixin for MiniCpmOModel {}