// mistralrs_core/vision_models/idefics3/mod.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3mod config;
4mod inputs_processor;
5mod vision;
6
7use std::any::Any;
8
9use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
10pub use config::Idefics3Config;
11pub use inputs_processor::Idefics3Processor;
12use mistralrs_quant::ShardedVarBuilder;
13use vision::{Idefics3Connector, Idefics3VisionTransformer};
14
15use crate::{
16    amoe::{AnyMoeBaseModelMixin, MlpLayer},
17    device_map::DeviceMapper,
18    models::llama::Llama,
19    paged_attention::{AttentionImplementation, ModelConfigMetadata},
20    pipeline::{
21        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
22        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
23    },
24    utils::unvarbuilder::UnVarBuilder,
25    AnyMoeConfig, AnyMoeExpertType,
26};
27
/// Idefics3 multimodal model: a vision transformer whose patch hidden states
/// are projected through a connector into the embedding space of a Llama
/// text backbone, which produces the final logits.
pub struct Idefics3Model {
    // Llama decoder consuming the merged text+image embedding sequence.
    text_model: Llama,
    // Projects vision hidden states into the text embedding space
    // (loaded in F32 on the real device; see `new`).
    connector: Idefics3Connector,
    // Vision encoder producing per-patch hidden states (also loaded in F32).
    vision: Idefics3VisionTransformer,
    // Full model configuration (`image_token_id`, `vision_config`, ...).
    config: Idefics3Config,
    // Working dtype of the loaded weights (`vb.dtype()`); merged image
    // features are cast to this before entering the text model.
    dtype: DType,
}
35
impl Idefics3Model {
    /// Construct the model from safetensors under the `model` prefix.
    ///
    /// The connector and vision tower are forced to F32 on the main device,
    /// while the Llama text model follows `normal_loading_metadata`
    /// (device mapping, quantization, etc.).
    pub fn new(
        cfg: &Idefics3Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let connector = Idefics3Connector::new(
            cfg,
            vb_m.pp("connector")
                .set_dtype(DType::F32)
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        let vision = Idefics3VisionTransformer::new(
            &cfg.vision_config,
            vb_m.pp("vision_model")
                .set_dtype(DType::F32)
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        // `new_inner` takes separate var builders because the LM head lives at
        // the top level ("lm_head") while the decoder is under
        // "model.text_model".
        let text_model = Llama::new_inner(
            &cfg.text_config,
            vb_m.pp("text_model"),
            vb.pp("lm_head"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        Ok(Self {
            text_model,
            connector,
            vision,
            config: cfg.clone(),
            // Working dtype of the loaded weights; merged embeddings are cast
            // back to this before the text model runs.
            dtype: vb.dtype(),
        })
    }

    /// Replace every image-placeholder token embedding in `input_embeds` with
    /// the next projected image hidden state, in order of appearance.
    ///
    /// Only batch size 1 is supported here (asserted below).
    fn inputs_merger(
        &self,
        input_ids: &Tensor,
        input_embeds: &Tensor,
        image_hidden_states: &Tensor,
    ) -> Result<Tensor> {
        // Docs copied from Transformers impl
        /*
        This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
        The merging happens as follows:
        - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
        - We get the image hidden states for the image through the vision encoder (and potentially the perceiver), and that hidden state is then projected into the text embedding space.
        We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
        - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
        - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
        */
        let (_, _, vision_hidden_size) = image_hidden_states.dims3()?;
        let bs = input_ids.dim(0)?;
        // Positions holding the special `<image>` placeholder token.
        let special_image_token_mask = input_ids.eq(self.config.image_token_id as f64)?;
        let mut new_inputs_embeds = input_embeds.clone();
        // Collapse all image hidden states into (bs, total_image_tokens, hidden).
        let reshaped_image_hidden_states =
            image_hidden_states.reshape((bs, (), vision_hidden_size))?;
        // The positional walk below only handles a single sequence.
        assert_eq!(input_embeds.dim(0)?, 1);
        assert_eq!(reshaped_image_hidden_states.dim(0)?, 1);
        let special_image_token_mask = special_image_token_mask.i(0)?.to_vec1::<u8>()?;
        let mut image_hidden_state_i = 0;
        for (i, v) in special_image_token_mask.iter().enumerate() {
            if *v != 0 {
                // Overwrite this token position with the next image hidden
                // state, consumed left-to-right.
                new_inputs_embeds = new_inputs_embeds.slice_assign(
                    &[&.., &i, &..],
                    &reshaped_image_hidden_states
                        .i((.., image_hidden_state_i, ..))?
                        .unsqueeze(1)?,
                )?;
                image_hidden_state_i += 1;
            }
        }
        Ok(new_inputs_embeds)
    }

    /// Run one forward pass. When `pixel_values` is present (prompt step only),
    /// images are encoded, projected, and merged into the token embeddings
    /// before the text model runs.
    #[allow(clippy::too_many_arguments)]
    fn forward_inner(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        pixel_attention_mask: Option<Tensor>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let input_embeds = if let Some(pixel_values) = pixel_values {
            // == START VISUAL INPUTS INTEGRATION ==
            // Fold the per-sample image dim into the batch dim:
            // (bs, num_images, ...) -> (bs * num_images, ...).
            let (batch_size, num_images, _, _, _) = pixel_values.dims5()?;
            let mut s = vec![batch_size * num_images];
            s.extend(pixel_values.dims()[2..].to_vec());
            let pixel_values = pixel_values.reshape(s)?;

            // Remove padding images which are full of 0s
            let nb_values_per_image = pixel_values.dims()[1..].iter().product::<usize>();
            // An image is "real" unless every element is zero, i.e. the count
            // of zeros differs from the total element count.
            let real_images_inds = pixel_values
                .eq(0.0f64)?
                .sum(vec![
                    pixel_values.dims().len() - 1,
                    pixel_values.dims().len() - 2,
                    pixel_values.dims().len() - 3,
                ])?
                .ne(nb_values_per_image as f64)?;
            let mut batches = Vec::new();
            for (batch, use_it) in pixel_values
                .chunk(pixel_values.dim(0)?, 0)?
                .iter()
                .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
            {
                let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
                if use_it {
                    batches.push(batch.clone());
                }
            }
            let pixel_values = Tensor::cat(&batches, 0)?;

            // Vision attention mask
            let pixel_attention_mask = if let Some(pixel_attention_mask) = pixel_attention_mask {
                let pixel_attention_mask = pixel_attention_mask.reshape((
                    batch_size * num_images,
                    pixel_attention_mask.dims()[2],
                    pixel_attention_mask.dims()[3],
                ))?;
                // Drop mask rows for the padding images removed above, using
                // the same real-image indicator.
                let mut batches = Vec::new();
                for (batch, use_it) in pixel_attention_mask
                    .chunk(pixel_attention_mask.dim(0)?, 0)?
                    .iter()
                    .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
                {
                    let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
                    if use_it {
                        batches.push(batch.clone());
                    }
                }
                Tensor::cat(&batches, 0)?
            } else {
                // No mask provided: treat every pixel as valid.
                Tensor::ones(
                    (
                        pixel_values.dims()[0],
                        pixel_values.dims()[2],
                        pixel_values.dims()[3],
                    ),
                    DType::U8,
                    pixel_values.device(),
                )?
            };

            // Reduce the pixel-level mask to a patch-level mask: a patch is
            // attended to if any pixel in its patch_size x patch_size window
            // is valid (sum over the unfolded window > 0).
            let patch_size = self.config.vision_config.patch_size;
            let patches_subgrid = pixel_attention_mask.unfold(1, patch_size, patch_size)?;
            let patches_subgrid = patches_subgrid.unfold(2, patch_size, patch_size)?;

            let patch_attention_mask = patches_subgrid
                .sum((D::Minus1, D::Minus2))?
                .gt(0.0)?
                .to_dtype(DType::U8)?;

            // NOTE(review): this cast is immediately superseded by the F32
            // cast below; if `self.dtype` is lower precision, the pixels
            // round-trip through it first — confirm whether that is intended.
            let pixel_values = pixel_values.to_dtype(self.dtype)?;

            // Get seq from vision encoder
            let image_hidden_states = self.vision.forward(
                &pixel_values.to_dtype(DType::F32)?,
                Some(&patch_attention_mask),
            )?;

            // Modality proj and perceiver resampling
            let image_hidden_states = self.connector.forward(&image_hidden_states)?;

            // Images may only be merged on the first (prompt) step, i.e. while
            // the KV cache is still empty.
            if self.text_model.cache().normal().0[0].current_seq_len() == 0 {
                self.inputs_merger(
                    input_ids,
                    &self
                        .text_model
                        .get_input_embeddings(input_ids)?
                        .to_dtype(DType::F32)?,
                    &image_hidden_states,
                )?
                .to_dtype(self.dtype)?
            } else {
                candle_core::bail!("Pixel values were specified for a non-prompt.")
            }
        } else {
            // Text-only step: plain token embeddings.
            self.text_model.get_input_embeddings(input_ids)?
        };

        self.text_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}
233
234impl IsqModel for Idefics3Model {
235    fn get_layers(
236        &mut self,
237    ) -> (
238        Vec<(
239            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
240            Option<usize>,
241        )>,
242        &dyn DeviceMapper,
243    ) {
244        self.text_model.get_layers()
245    }
246
247    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
248        let uvb = UnVarBuilder::new();
249
250        let uvb_m = uvb.pp("model");
251        uvb_m
252            .pp("connector")
253            .pp("modality_projection")
254            .pp("proj")
255            .add(&self.connector.modality_projection.proj);
256        uvb.extend(self.text_model.residual_tensors_m(uvb_m.pp("text_model")));
257        uvb_m
258            .pp("vision_model")
259            .extend(self.vision.residual_tensors());
260
261        uvb.to_safetensors()
262    }
263}
264
265// AnyMoE is forwarded to the base model
266impl AnyMoeBaseModelMixin for Idefics3Model {
267    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
268        self.text_model.get_mlps()
269    }
270    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
271        self.text_model.get_mlps_mut()
272    }
273    fn create_anymoe_layers(
274        &mut self,
275        additional_vbs: Vec<ShardedVarBuilder>,
276        config: AnyMoeConfig,
277        (prefix, mlp): (String, String),
278        layers: Vec<usize>,
279        expert_type: AnyMoeExpertType,
280        gate_vb: Option<ShardedVarBuilder>,
281    ) -> Result<()> {
282        self.text_model.create_anymoe_layers(
283            additional_vbs,
284            config,
285            (prefix, mlp),
286            layers,
287            expert_type,
288            gate_vb,
289        )
290    }
291    fn amoe_supported(&self) -> bool {
292        true
293    }
294}
295
296impl VisionModel for Idefics3Model {
297    fn forward(
298        &self,
299        input_ids: &Tensor,
300        pixel_values: Option<Tensor>,
301        seqlen_offsets: &[usize],
302        context_lens: Vec<(usize, usize)>,
303        _: Vec<usize>, // Ignore, it is for phi3
304        model_specific_args: Box<dyn Any>,
305        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
306        flash_params: &FlashParams,
307    ) -> candle_core::Result<Tensor> {
308        let pixel_attention_mask: Option<Tensor> = *model_specific_args
309            .downcast()
310            .expect("Cannot downcast into `Option<Tensor>`");
311        self.forward_inner(
312            input_ids,
313            pixel_values,
314            seqlen_offsets,
315            context_lens,
316            pixel_attention_mask,
317            metadata,
318            flash_params,
319        )
320    }
321    fn cache(&self) -> &EitherCache {
322        self.text_model.cache()
323    }
324    fn cache_mut(&mut self) -> &mut EitherCache {
325        self.text_model.cache_mut()
326    }
327    fn device(&self) -> &Device {
328        self.text_model.device()
329    }
330    fn max_seq_len(&self) -> usize {
331        self.text_model.max_seq_len()
332    }
333    fn has_conv2d(&self) -> bool {
334        true
335    }
336    fn config(&self) -> &ModelConfigMetadata {
337        self.text_model.config()
338    }
339    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
340        let args: Option<Tensor> = None;
341        Box::new(args)
342    }
343}