mistralrs_core/vision_models/minicpmo/mod.rs

use std::{any::Any, sync::Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
pub use config::MiniCpmOConfig;
pub use inputs_processor::MiniCpmOProcessor;
use mistralrs_quant::{CollectedImatrixData, QuantMethod, ShardedVarBuilder};
use resampler::Resampler;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    models::qwen2,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
    },
    utils::unvarbuilder::UnVarBuilder,
};

use self::siglip::SiglipVisionTransformer;

use super::siglip;

mod config;
mod inputs_processor;
mod resampler;

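/// MiniCPM-O: a Qwen2 language model (`llm`) combined with a SigLIP vision
/// transformer (`vpm`) and a `Resampler` that maps vision features into the
/// text embedding space.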
pub struct MiniCpmOModel {
    cfg: MiniCpmOConfig,
    llm: qwen2::Model,
    vpm: SiglipVisionTransformer,
    resampler: Resampler,
}

impl MiniCpmOModel {
    pub fn new(
        cfg: &MiniCpmOConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
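        // The LLM may be device-mapped via `normal_loading_metadata`; the
        // vision tower and resampler are loaded wholly on the primary device.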
        let real_device = normal_loading_metadata.real_device.clone();
        let llm = qwen2::Model::new(
            &cfg.text_config,
            vb.pp("llm"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        let vpm = SiglipVisionTransformer::new(
            &cfg.vision_config,
            vb.pp("vpm").set_device(real_device.clone()),
        )?;
        let resampler = Resampler::new(
            cfg.query_num,
            cfg.text_config.hidden_size,
            cfg.text_config.hidden_size / 128,
            cfg.vision_config.hidden_size,
            true,
            None,
            vb.pp("resampler").set_device(real_device.clone()),
        )?;
        Ok(Self {
            cfg: cfg.clone(),
            llm,
            vpm,
            resampler,
        })
    }

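    /// Compute the input embeddings for the LLM, splicing resampled vision
    /// features into the token embeddings at the positions given by
    /// `image_bound`.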
    fn get_vllm_embedding(
        &self,
        input_ids: &Tensor,
        device: &Device,
        pixel_values_all: Option<Vec<Vec<Tensor>>>,
        tgt_sizes: Option<Vec<Tensor>>,
        image_bound: Option<Vec<Tensor>>,
    ) -> Result<Tensor> {
        let mut vllm_embedding = self.llm.get_input_embeddings(input_ids)?;

        if let Some(pixel_values_all) = pixel_values_all {
            let tgt_sizes_all = tgt_sizes.as_ref().expect("Need tgt_sizes");
            let image_bound = image_bound.expect("Need image_bound");
            let image_bound_vec = image_bound
                .into_iter()
                .map(|x| x.to_vec2::<u32>())
                .collect::<Result<Vec<_>>>()?;

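            // Flatten each image into a patch sequence and collect the
            // sequences across the whole batch, remembering how many images
            // belong to each input.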
            let mut all_pixel_values = Vec::new();
            let mut img_cnts = Vec::new();
            for pixel_values in &pixel_values_all {
                img_cnts.push(pixel_values.len());
                let mut imgs = Vec::new();
                for i in pixel_values {
                    imgs.push(i.flatten_to(1)?.permute((1, 0))?);
                }
                all_pixel_values.extend(imgs);
            }

            let tgt_sizes = Tensor::cat(tgt_sizes_all, 0)?;
            let tgt_sizes_vec = tgt_sizes.to_vec2::<u32>()?;

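            // Largest patch count (h * w) over all images; this bounds the
            // width of the patch attention mask below.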
            let max_patches = (tgt_sizes.i((.., 0))? * tgt_sizes.i((.., 1))?)?
                .max(0)?
                .to_scalar::<u32>()? as usize;

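            // Zero-pad every patch sequence to the longest one so they can be
            // stacked into a single batch tensor.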
            let lens = all_pixel_values
                .iter()
                .map(|pixel_values| pixel_values.dim(0))
                .collect::<Result<Vec<_>>>()?;
            let max_len = lens.into_iter().max().expect("No pixel values somehow?");
            all_pixel_values = all_pixel_values
                .into_iter()
                .map(|pixel_values| {
                    pixel_values.pad_with_zeros(0, 0, max_len - pixel_values.dim(0)?)
                })
                .collect::<Result<Vec<_>>>()?;
            let mut all_pixel_values = Tensor::stack(&all_pixel_values, 0)?;

            let (b, l, _) = all_pixel_values.dims3()?;
            all_pixel_values = all_pixel_values
                .permute((0, 2, 1))?
                .reshape((b, 3, (), l))?;

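            // Mark only the real (non-padded) patches of each image as
            // attendable.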
            let mut patch_attn_mask = Tensor::zeros((b, 1, max_patches), DType::U8, device)?;
            for (i, tgt_sizes_vec_i) in tgt_sizes_vec.iter().enumerate().take(b) {
                let n = (tgt_sizes_vec_i[0] * tgt_sizes_vec_i[1]) as usize;
                patch_attn_mask = patch_attn_mask.slice_assign(
                    &[&i, &0, &(..n)],
                    &Tensor::ones((1, 1, n), DType::U8, device)?,
                )?;
            }

            let vision_batch_size = self.cfg.vision_batch_size;
            all_pixel_values = all_pixel_values.to_dtype(self.llm.embed_dtype())?;

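            // Run the vision tower, chunking the image batch by
            // `vision_batch_size` to cap peak memory, then resample the patch
            // features down to `query_num` tokens per image.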
            let mut vision_embedding = if b > vision_batch_size {
                let mut hs = Vec::new();
                for i in (0..b).step_by(vision_batch_size) {
                    let start_idx = i;
                    let end_idx = i + vision_batch_size;
                    let tmp_hs = self.vpm.forward(
                        &all_pixel_values.i(start_idx..end_idx)?,
                        Some(&patch_attn_mask.i(start_idx..end_idx)?),
                        Some(&tgt_sizes.i(start_idx..end_idx)?),
                    )?;
                    hs.push(tmp_hs);
                }
                Tensor::cat(&hs, 0)?
            } else {
                self.vpm
                    .forward(&all_pixel_values, Some(&patch_attn_mask), Some(&tgt_sizes))?
            };
            vision_embedding = self.resampler.forward(&vision_embedding, &tgt_sizes_vec)?;

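            // Regroup the flat vision embeddings per input sequence; `None`
            // marks sequences that carry no images.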
            let mut start = 0;
            let mut vision_hidden_states = Vec::new();
            for pixel_values in &pixel_values_all {
                let img_cnt = pixel_values.len();
                if img_cnt > 0 {
                    vision_hidden_states.push(Some(
                        vision_embedding
                            .i(start..start + img_cnt)?
                            .to_dtype(vllm_embedding.dtype())?,
                    ));
                    start += img_cnt;
                } else {
                    vision_hidden_states.push(None);
                }
            }

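            // Splice the vision embeddings into the token embeddings at the
            // positions covered by each image bound.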
            let mut new_vllm_embedding = Vec::new();
            for i in 0..input_ids.dim(0)? {
                if let Some(cur_vs_hs) = &vision_hidden_states[i] {
                    let mut cur_vllm_emb = vllm_embedding.i(i)?;
                    let cur_image_bound = &image_bound_vec[i];
                    if !cur_image_bound.is_empty() {
                        let mut image_indices = Vec::new();
                        for r in cur_image_bound {
                            image_indices.push(Tensor::arange(r[0], r[1], device)?);
                        }
                        let image_indices = Tensor::stack(&image_indices, 0)?;

                        let indices = image_indices
                            .reshape(((), 1))?
                            .repeat((1, cur_vllm_emb.dim(D::Minus1)?))?;
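                        // Zero out the existing embeddings at the image
                        // positions by adding their negation, then scatter-add
                        // the vision features in their place.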
                        let cur_vllm_emb_neg = cur_vllm_emb.gather(&indices, 0)?.neg()?;
                        cur_vllm_emb = cur_vllm_emb.scatter_add(&indices, &cur_vllm_emb_neg, 0)?;
                        cur_vllm_emb = cur_vllm_emb.scatter_add(
                            &indices,
                            &cur_vs_hs.reshape(((), cur_vs_hs.dim(D::Minus1)?))?,
                            0,
                        )?;
                        new_vllm_embedding.push(cur_vllm_emb);
                    }
                }
            }
            vllm_embedding = Tensor::stack(&new_vllm_embedding, 0)?;
        }

        Ok(vllm_embedding)
    }

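    /// Full forward pass: build the (optionally vision-augmented) input
    /// embeddings, then run the language model over them.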
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values_all: Option<Vec<Vec<Tensor>>>,
        tgt_sizes: Option<Vec<Tensor>>,
        image_bound: Option<Vec<Tensor>>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let vllm_embedding = self.get_vllm_embedding(
            input_ids,
            self.llm.device(),
            pixel_values_all,
            tgt_sizes,
            image_bound,
        )?;

        self.llm.forward_embed(
            input_ids,
            vllm_embedding,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}

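/// Vision-specific arguments threaded through `VisionModel::forward` as a
/// `Box<dyn Any>`.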
#[derive(Default)]
pub(crate) struct MiniCpmOSpecificArgs {
    pub(crate) pixel_values_all: Option<Vec<Vec<Tensor>>>,
    pub(crate) tgt_sizes: Option<Vec<Tensor>>,
    pub(crate) image_bound: Option<Vec<Tensor>>,
}

impl VisionModel for MiniCpmOModel {
    fn cache(&self) -> &EitherCache {
        self.llm.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.llm.cache_mut()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.llm.config()
    }
    fn device(&self) -> &Device {
        self.llm.device()
    }
    fn has_conv2d(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.llm.max_seq_len()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        _pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let MiniCpmOSpecificArgs {
            pixel_values_all,
            tgt_sizes,
            image_bound,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `MiniCpmOSpecificArgs`");
        self.forward(
            input_ids,
            pixel_values_all,
            tgt_sizes,
            image_bound,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(MiniCpmOSpecificArgs {
            pixel_values_all: None,
            tgt_sizes: None,
            image_bound: None,
        })
    }
}

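// ISQ quantizes only the language model's layers; the vision tower and
// resampler contribute residual (non-quantized) tensors.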
impl IsqModel for MiniCpmOModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.llm.get_layers()
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("llm").extend(self.llm.residual_tensors());
        uvb.pp("vpm").extend(self.vpm.residual_tensors());
        uvb.pp("resampler")
            .extend(self.resampler.residual_tensors());

        uvb.to_safetensors()
    }

    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
        self.llm.begin_track_stats()
    }

    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
        self.llm.extract_imatrix_data()
    }
}

impl AnyMoeBaseModelMixin for MiniCpmOModel {}