// mistralrs_core/vision_models/idefics3/mod.rs
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3mod config;
4mod inputs_processor;
5mod vision;
6
7use std::any::Any;
8
9use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
10pub use config::Idefics3Config;
11pub use inputs_processor::Idefics3Processor;
12use mistralrs_quant::ShardedVarBuilder;
13use vision::{Idefics3Connector, Idefics3VisionTransformer};
14
15use crate::{
16 amoe::{AnyMoeBaseModelMixin, MlpLayer},
17 device_map::DeviceMapper,
18 models::llama::Llama,
19 paged_attention::{AttentionImplementation, ModelConfigMetadata},
20 pipeline::{
21 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
22 EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
23 },
24 utils::unvarbuilder::UnVarBuilder,
25 AnyMoeConfig, AnyMoeExpertType,
26};
27
/// Idefics3 vision-language model: a vision transformer feeding image hidden
/// states through a connector into a Llama text model.
pub struct Idefics3Model {
    // Llama backbone that consumes the merged text+image embeddings and
    // produces logits.
    text_model: Llama,
    // Projects vision-tower hidden states into the text model's embedding
    // space (loaded in F32 in `new`).
    connector: Idefics3Connector,
    // Vision transformer producing per-patch hidden states (loaded in F32 in `new`).
    vision: Idefics3VisionTransformer,
    // Full model configuration, cloned at load time (used for
    // `image_token_id` and `vision_config.patch_size`).
    config: Idefics3Config,
    // Working dtype of the top-level var builder; embeddings are converted
    // back to this dtype after the F32 merge in `forward_inner`.
    dtype: DType,
}
35
impl Idefics3Model {
    /// Load the full Idefics3 model from a sharded var builder.
    ///
    /// The connector and vision tower are forced to F32 and placed on the
    /// real (non-mapped) device; the Llama text model is loaded through
    /// `Llama::new_inner` with the usual loading metadata and attention
    /// implementation.
    pub fn new(
        cfg: &Idefics3Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        // Vision->text projection, kept in F32 on the main device.
        let connector = Idefics3Connector::new(
            cfg,
            vb_m.pp("connector")
                .set_dtype(DType::F32)
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        // Vision tower, also F32 on the main device.
        let vision = Idefics3VisionTransformer::new(
            &cfg.vision_config,
            vb_m.pp("vision_model")
                .set_dtype(DType::F32)
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        // Text backbone; `lm_head` lives at the top level of the checkpoint,
        // not under `model.`.
        let text_model = Llama::new_inner(
            &cfg.text_config,
            vb_m.pp("text_model"),
            vb.pp("lm_head"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        Ok(Self {
            text_model,
            connector,
            vision,
            config: cfg.clone(),
            // Remember the overall working dtype for converting merged
            // embeddings back after the F32 vision pass.
            dtype: vb.dtype(),
        })
    }

    /// Replace each image-placeholder token's embedding in `input_embeds`
    /// with the next image hidden state, in order.
    ///
    /// - `input_ids`: (bs, seq) token ids; positions equal to
    ///   `config.image_token_id` are replaced.
    /// - `input_embeds`: (bs, seq, hidden) text embeddings.
    /// - `image_hidden_states`: connector output, reshaped here to
    ///   (bs, total_image_tokens, vision_hidden_size).
    ///
    /// Only batch size 1 is supported — enforced by the two `assert_eq!`s
    /// below. The number of placeholder tokens in `input_ids` is assumed to
    /// match the number of image hidden states; no check is made here.
    fn inputs_merger(
        &self,
        input_ids: &Tensor,
        input_embeds: &Tensor,
        image_hidden_states: &Tensor,
    ) -> Result<Tensor> {
        let (_, _, vision_hidden_size) = image_hidden_states.dims3()?;
        let bs = input_ids.dim(0)?;
        // 1 where the token is the image placeholder, 0 elsewhere.
        let special_image_token_mask = input_ids.eq(self.config.image_token_id as f64)?;
        let mut new_inputs_embeds = input_embeds.clone();
        // Flatten all images' tokens into one (bs, n, hidden) sequence.
        let reshaped_image_hidden_states =
            image_hidden_states.reshape((bs, (), vision_hidden_size))?;
        // Batch-size-1 restriction: the scatter below indexes batch 0 only.
        assert_eq!(input_embeds.dim(0)?, 1);
        assert_eq!(reshaped_image_hidden_states.dim(0)?, 1);
        let special_image_token_mask = special_image_token_mask.i(0)?.to_vec1::<u8>()?;
        let mut image_hidden_state_i = 0;
        // One slice_assign per placeholder token; image states are consumed
        // left-to-right. NOTE(review): this is O(placeholders) tensor ops —
        // acceptable for prefill, but a vectorized scatter would be cheaper.
        for (i, v) in special_image_token_mask.iter().enumerate() {
            if *v != 0 {
                new_inputs_embeds = new_inputs_embeds.slice_assign(
                    &[&.., &i, &..],
                    &reshaped_image_hidden_states
                        .i((.., image_hidden_state_i, ..))?
                        .unsqueeze(1)?,
                )?;
                image_hidden_state_i += 1;
            }
        }
        Ok(new_inputs_embeds)
    }

    /// Core forward pass.
    ///
    /// When `pixel_values` is present (prompt/prefill only), runs the vision
    /// tower + connector and splices image hidden states into the text
    /// embeddings; otherwise embeds `input_ids` directly. Either way the
    /// result is fed to the text model via `forward_embeds`.
    ///
    /// `pixel_values` is expected to be (bs, num_images, C, H, W) — enforced
    /// by `dims5`. `pixel_attention_mask`, when given, is assumed to be
    /// (bs, num_images, H, W) — TODO confirm against the inputs processor.
    #[allow(clippy::too_many_arguments)]
    fn forward_inner(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        pixel_attention_mask: Option<Tensor>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let input_embeds = if let Some(pixel_values) = pixel_values {
            let (batch_size, num_images, _, _, _) = pixel_values.dims5()?;
            // Fold batch and image dims together: (bs*num_images, C, H, W).
            let mut s = vec![batch_size * num_images];
            s.extend(pixel_values.dims()[2..].to_vec());
            let pixel_values = pixel_values.reshape(s)?;

            // An image slot is "real" unless every value in it is zero
            // (all-zero slots are padding from the inputs processor).
            let nb_values_per_image = pixel_values.dims()[1..].iter().product::<usize>();
            let real_images_inds = pixel_values
                .eq(0.0f64)?
                .sum(vec![
                    pixel_values.dims().len() - 1,
                    pixel_values.dims().len() - 2,
                    pixel_values.dims().len() - 3,
                ])?
                .ne(nb_values_per_image as f64)?;
            // Keep only the real image slots.
            let mut batches = Vec::new();
            for (batch, use_it) in pixel_values
                .chunk(pixel_values.dim(0)?, 0)?
                .iter()
                .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
            {
                let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
                if use_it {
                    batches.push(batch.clone());
                }
            }
            let pixel_values = Tensor::cat(&batches, 0)?;

            // Filter the pixel attention mask the same way, or default to
            // all-ones (every pixel valid) when none was provided.
            let pixel_attention_mask = if let Some(pixel_attention_mask) = pixel_attention_mask {
                let pixel_attention_mask = pixel_attention_mask.reshape((
                    batch_size * num_images,
                    pixel_attention_mask.dims()[2],
                    pixel_attention_mask.dims()[3],
                ))?;
                let mut batches = Vec::new();
                for (batch, use_it) in pixel_attention_mask
                    .chunk(pixel_attention_mask.dim(0)?, 0)?
                    .iter()
                    .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
                {
                    let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
                    if use_it {
                        batches.push(batch.clone());
                    }
                }
                Tensor::cat(&batches, 0)?
            } else {
                Tensor::ones(
                    (
                        pixel_values.dims()[0],
                        pixel_values.dims()[2],
                        pixel_values.dims()[3],
                    ),
                    DType::U8,
                    pixel_values.device(),
                )?
            };

            // Collapse the pixel mask to a per-patch mask: a patch attends
            // iff any pixel in its patch_size x patch_size window is valid.
            let patch_size = self.config.vision_config.patch_size;
            let patches_subgrid = pixel_attention_mask.unfold(1, patch_size, patch_size)?;
            let patches_subgrid = patches_subgrid.unfold(2, patch_size, patch_size)?;

            let patch_attention_mask = patches_subgrid
                .sum((D::Minus1, D::Minus2))?
                .gt(0.0)?
                .to_dtype(DType::U8)?;

            // NOTE(review): converted to self.dtype here, then immediately
            // back to F32 for the vision tower below. If self.dtype is
            // half-precision this round-trip quantizes the pixel values —
            // confirm whether that is intentional.
            let pixel_values = pixel_values.to_dtype(self.dtype)?;

            // Vision tower and connector both run in F32 (see `new`).
            let image_hidden_states = self.vision.forward(
                &pixel_values.to_dtype(DType::F32)?,
                Some(&patch_attention_mask),
            )?;

            let image_hidden_states = self.connector.forward(&image_hidden_states)?;

            // Image inputs are only valid on the first (prefill) step, i.e.
            // when the KV cache is still empty.
            if self.text_model.cache().normal().0[0].current_seq_len() == 0 {
                self.inputs_merger(
                    input_ids,
                    &self
                        .text_model
                        .get_input_embeddings(input_ids)?
                        .to_dtype(DType::F32)?,
                    &image_hidden_states,
                )?
                // Back to the text model's working dtype after the F32 merge.
                .to_dtype(self.dtype)?
            } else {
                candle_core::bail!("Pixel values were specified for a non-prompt.")
            }
        } else {
            // Text-only step: plain token embeddings.
            self.text_model.get_input_embeddings(input_ids)?
        };

        self.text_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}
233
234impl IsqModel for Idefics3Model {
235 fn get_layers(
236 &mut self,
237 ) -> (
238 Vec<(
239 &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
240 Option<usize>,
241 )>,
242 &dyn DeviceMapper,
243 ) {
244 self.text_model.get_layers()
245 }
246
247 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
248 let uvb = UnVarBuilder::new();
249
250 let uvb_m = uvb.pp("model");
251 uvb_m
252 .pp("connector")
253 .pp("modality_projection")
254 .pp("proj")
255 .add(&self.connector.modality_projection.proj);
256 uvb.extend(self.text_model.residual_tensors_m(uvb_m.pp("text_model")));
257 uvb_m
258 .pp("vision_model")
259 .extend(self.vision.residual_tensors());
260
261 uvb.to_safetensors()
262 }
263}
264
265impl AnyMoeBaseModelMixin for Idefics3Model {
267 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
268 self.text_model.get_mlps()
269 }
270 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
271 self.text_model.get_mlps_mut()
272 }
273 fn create_anymoe_layers(
274 &mut self,
275 additional_vbs: Vec<ShardedVarBuilder>,
276 config: AnyMoeConfig,
277 (prefix, mlp): (String, String),
278 layers: Vec<usize>,
279 expert_type: AnyMoeExpertType,
280 gate_vb: Option<ShardedVarBuilder>,
281 ) -> Result<()> {
282 self.text_model.create_anymoe_layers(
283 additional_vbs,
284 config,
285 (prefix, mlp),
286 layers,
287 expert_type,
288 gate_vb,
289 )
290 }
291 fn amoe_supported(&self) -> bool {
292 true
293 }
294}
295
296impl VisionModel for Idefics3Model {
297 fn forward(
298 &self,
299 input_ids: &Tensor,
300 pixel_values: Option<Tensor>,
301 seqlen_offsets: &[usize],
302 context_lens: Vec<(usize, usize)>,
303 _: Vec<usize>, model_specific_args: Box<dyn Any>,
305 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
306 flash_params: &FlashParams,
307 ) -> candle_core::Result<Tensor> {
308 let pixel_attention_mask: Option<Tensor> = *model_specific_args
309 .downcast()
310 .expect("Cannot downcast into `Option<Tensor>`");
311 self.forward_inner(
312 input_ids,
313 pixel_values,
314 seqlen_offsets,
315 context_lens,
316 pixel_attention_mask,
317 metadata,
318 flash_params,
319 )
320 }
321 fn cache(&self) -> &EitherCache {
322 self.text_model.cache()
323 }
324 fn cache_mut(&mut self) -> &mut EitherCache {
325 self.text_model.cache_mut()
326 }
327 fn device(&self) -> &Device {
328 self.text_model.device()
329 }
330 fn max_seq_len(&self) -> usize {
331 self.text_model.max_seq_len()
332 }
333 fn has_conv2d(&self) -> bool {
334 true
335 }
336 fn config(&self) -> &ModelConfigMetadata {
337 self.text_model.config()
338 }
339 fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
340 let args: Option<Tensor> = None;
341 Box::new(args)
342 }
343}