mistralrs_core/vision_models/mllama/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

mod config;
mod inputs_processor;
mod text;
mod vision;

use std::{any::Any, collections::HashMap, sync::Arc};

pub(crate) use config::{MLlamaConfig, MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig};
use config::{MLlamaVisionConfig, VisionActivation};
pub(crate) use inputs_processor::MLlamaProcessor;
use text::MLlamaTextModel;
use vision::MLlamaVisionModel;

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
use mistralrs_quant::{CollectedImatrixData, QuantMethod, ShardedVarBuilder};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::{linear, GetFloatInfo},
    layers_masker::masked_fill,
    ops::RepeatInterleaveOp,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
    },
    utils::unvarbuilder::UnVarBuilder,
};

// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L99
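/// Expand the per-tile cross-attention mask to per-vision-token granularity and turn it
/// into an additive mask: 0 where a text position may attend to a vision token, the
/// dtype's minimum (effectively -inf) where it may not.
///
/// Also returns `full_text_row_masked_out_mask` of shape (bs, 1, text_len, 1), which is
/// 0 for text positions whose entire vision row is masked out and 1 otherwise.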
fn prepare_cross_attention_mask(
    cross_attention_mask: &Tensor,
    num_vision_tokens: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let bs = cross_attention_mask.dim(0)?;
    let text_total_length = cross_attention_mask.dim(1)?;
    let mut cross_attn_mask = cross_attention_mask
        .to_dtype(DType::F32)?
        .repeat_interleave(num_vision_tokens, 3)?;
    cross_attn_mask = cross_attn_mask.reshape((bs, text_total_length, ()))?;
    cross_attn_mask = cross_attn_mask.unsqueeze(1)?;

    // Invert the mask
    let inverted_cross_attn_mask = (1. - cross_attn_mask)?;
    let neg_inf_value = dtype.finfo()?.min;
    cross_attn_mask = masked_fill(
        &inverted_cross_attn_mask,
        &inverted_cross_attn_mask.ne(0.)?,
        neg_inf_value as f32,
    )?;

    // Apply a full-row bias, which returns a 4D tensor of shape (b, h, s1, 1) whose
    // value is 0 if a full row in the cross-attention mask's last dimension contains only
    // negative infinity values, and 1 otherwise.
    let full_text_row_masked_out_mask = cross_attn_mask
        .ne(neg_inf_value)?
        .sum(D::Minus1)?
        .ne(0.)?
        .unsqueeze(D::Minus1)?;

    cross_attn_mask = cross_attn_mask
        .broadcast_mul(&full_text_row_masked_out_mask.to_dtype(cross_attn_mask.dtype())?)?
        .to_dtype(DType::F32)?
        .to_dtype(dtype)?;

    Ok((cross_attn_mask, full_text_row_masked_out_mask))
}

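/// Multimodal mllama model: vision encoder outputs are projected into the text model's
/// hidden size by `multi_modal_projector` and fed to the language model as
/// cross-attention states.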
pub(crate) struct MLlamaModel {
    vision_model: MLlamaVisionModel,
    language_model: MLlamaTextModel,
    multi_modal_projector: Linear,
    hidden_size: usize,
    dtype: DType,
}

impl MLlamaModel {
    pub(crate) fn new(
        cfg: &MLlamaConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let real_dev = normal_loading_metadata.real_device.clone();
        Ok(Self {
            vision_model: MLlamaVisionModel::new(
                &cfg.vision_config,
                vb.pp("vision_model"),
                &real_dev,
                &normal_loading_metadata.mapper.get_comm_for(0)?,
            )?,
            language_model: MLlamaTextModel::new(
                &cfg.text_config,
                vb.pp("language_model"),
                is_gptx,
                normal_loading_metadata,
                attention_mechanism,
            )?,
            multi_modal_projector: linear(
                cfg.vision_config.vision_output_dim,
                cfg.text_config.hidden_size,
                vb.pp("multi_modal_projector").set_device(real_dev.clone()),
            )?,
            hidden_size: cfg.text_config.hidden_size,
            dtype: vb.dtype(),
        })
    }

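    /// Encode `pixel_values` (when present) into cross-attention states via the vision
    /// model and `multi_modal_projector`, prepare the cross-attention masks, and run the
    /// language model.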
    #[allow(clippy::too_many_arguments)]
    fn forward_inner(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<&Tensor>,
        aspect_ratio_mask: Option<&Tensor>,
        aspect_ratio_ids: Option<&Tensor>,
        cross_attn_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
    ) -> Result<Tensor> {
        let cross_attn_states = if let Some(pixel_values) = pixel_values {
            let Some(aspect_ratio_mask) = aspect_ratio_mask else {
                candle_core::bail!("`aspect_ratio_mask` must be specified if `pixel_values` is.");
            };
            let Some(aspect_ratio_ids) = aspect_ratio_ids else {
                candle_core::bail!("`aspect_ratio_ids` must be specified if `pixel_values` is.");
            };
            let vision_outputs =
                self.vision_model
                    .forward(pixel_values, aspect_ratio_ids, aspect_ratio_mask)?;
            let cross_attention_states = self
                .multi_modal_projector
                .forward(&vision_outputs.flatten(0, 1)?)?
                .reshape(((), vision_outputs.dim(D::Minus2)?, self.hidden_size))?
                .to_dtype(self.dtype)?;
            Some(cross_attention_states)
        } else {
            None
        };

        let (cross_attn_mask, full_text_row_masked_out_mask) =
            if let Some(cross_attn_mask) = cross_attn_mask {
                let (mut cmask, fmask) = prepare_cross_attention_mask(
                    cross_attn_mask,
                    self.vision_model.num_patches,
                    self.dtype,
                )?;
                cmask = cmask.squeeze(1)?;
                (Some(cmask), Some(fmask))
            } else {
                (None, None)
            };

        self.language_model.forward(
            input_ids,
            cross_attn_states.as_ref(),
            cross_attn_mask.as_ref(),
            full_text_row_masked_out_mask.as_ref(),
            seqlen_offsets,
            context_lens,
        )
    }
}

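/// Vision-specific arguments threaded through `VisionModel::forward` as `Box<dyn Any>`.
/// All fields default to `None`, which corresponds to a text-only forward pass.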
#[derive(Default)]
pub(crate) struct MLlamaSpecificArgs {
    pub aspect_ratio_ids: Option<Tensor>,
    pub aspect_ratio_mask: Option<Tensor>,
    pub cross_attn_mask: Option<Tensor>,
}

impl VisionModel for MLlamaModel {
    fn cache(&self) -> &EitherCache {
        &self.language_model.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.language_model.cache
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.language_model.cfg
    }
    fn device(&self) -> &Device {
        &self.language_model.device
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_position_embeddings
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let MLlamaSpecificArgs {
            aspect_ratio_ids,
            aspect_ratio_mask,
            cross_attn_mask,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `MLlamaSpecificArgs`");
        self.forward_inner(
            input_ids,
            pixel_values.as_ref(),
            aspect_ratio_mask.as_ref(),
            aspect_ratio_ids.as_ref(),
            cross_attn_mask.as_ref(),
            seqlen_offsets,
            context_lens,
        )
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(MLlamaSpecificArgs::default())
    }
}
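// Callers of `VisionModel::forward` must box an `MLlamaSpecificArgs` for
// `model_specific_args`; any other type hits the `downcast` expect above. An illustrative
// sketch (hypothetical caller; the tensors would normally come from the inputs processor):
//
//     let args = Box::new(MLlamaSpecificArgs {
//         aspect_ratio_ids: Some(aspect_ratio_ids),
//         aspect_ratio_mask: Some(aspect_ratio_mask),
//         cross_attn_mask: Some(cross_attn_mask),
//     });
//     let logits = model.forward(
//         &input_ids,
//         Some(pixel_values),
//         &seqlen_offsets,
//         context_lens,
//         vec![],
//         args,
//         None,
//         &flash_params,
//     )?;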

impl IsqModel for MLlamaModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let (mut layers, mapper) = self.language_model.get_layers();
        layers.extend(
            self.vision_model
                .get_isq_layers()
                .into_iter()
                .map(|layer| (layer, None)),
        );
        (layers, mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("multi_modal_projector")
            .add(&self.multi_modal_projector);
        uvb.pp("language_model")
            .extend(self.language_model.residual_tensors());
        uvb.pp("vision_model")
            .extend(self.vision_model.residual_tensors());

        uvb.to_safetensors()
    }

    // NOTE: We ONLY calibrate the text bits of these models, so we should only track/return those parts!!

    /// This is used for imatrix generation internally. Begin stats tracking.
    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
        let layers = self
            .language_model
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

    /// End stats tracking and return the imatrix data
    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .language_model
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }
}

impl AnyMoeBaseModelMixin for MLlamaModel {}