mistralrs_core/vision_models/mllama/mod.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3mod config;
4mod inputs_processor;
5mod text;
6mod vision;
7
8use std::{any::Any, collections::HashMap, sync::Arc};
9
10pub(crate) use config::{MLlamaConfig, MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig};
11use config::{MLlamaVisionConfig, VisionActivation};
12pub(crate) use inputs_processor::MLlamaProcessor;
13use text::MLlamaTextModel;
14use vision::MLlamaVisionModel;
15
16use candle_core::{DType, Device, Result, Tensor, D};
17use candle_nn::{Linear, Module};
18use mistralrs_quant::{CollectedImatrixData, QuantMethod, ShardedVarBuilder};
19
20use crate::{
21    amoe::AnyMoeBaseModelMixin,
22    device_map::DeviceMapper,
23    layers::{linear, GetFloatInfo},
24    layers_masker::masked_fill,
25    ops::RepeatInterleaveOp,
26    paged_attention::{AttentionImplementation, ModelConfigMetadata},
27    pipeline::{
28        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
29        EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
30    },
31    utils::unvarbuilder::UnVarBuilder,
32};
33
34// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L99
35fn prepare_cross_attention_mask(
36    cross_attention_mask: &Tensor,
37    num_vision_tokens: usize,
38    dtype: DType,
39) -> Result<(Tensor, Tensor)> {
40    let bs = cross_attention_mask.dim(0)?;
41    let text_total_length = cross_attention_mask.dim(1)?;
42    let mut cross_attn_mask = cross_attention_mask
43        .to_dtype(DType::F32)?
44        .repeat_interleave(num_vision_tokens, 3)?;
45    cross_attn_mask = cross_attn_mask.reshape((bs, text_total_length, ()))?;
46    cross_attn_mask = cross_attn_mask.unsqueeze(1)?;
47
48    // Invert the mask
49    let inverted_cross_attn_mask = (1. - cross_attn_mask)?;
50    let neg_inf_value = dtype.finfo()?.min;
51    cross_attn_mask = masked_fill(
52        &inverted_cross_attn_mask,
53        &inverted_cross_attn_mask.ne(0.)?,
54        neg_inf_value as f32,
55    )?;
56
57    // Apply full-row bias which return 4d tensor of shape (b, h, s1, 1) where
58    // value is 0 if a full row in cross attn mask's last dimension contains
59    // negative infinity values, otherwise it's 1
60    let full_text_row_masked_out_mask = cross_attn_mask
61        .ne(neg_inf_value)?
62        .sum(D::Minus1)?
63        .ne(0.)?
64        .unsqueeze(D::Minus1)?;
65
66    cross_attn_mask = cross_attn_mask
67        .broadcast_mul(&full_text_row_masked_out_mask.to_dtype(cross_attn_mask.dtype())?)?
68        .to_dtype(DType::F32)?
69        .to_dtype(dtype)?;
70
71    Ok((cross_attn_mask, full_text_row_masked_out_mask))
72}
73
74pub(crate) struct MLlamaModel {
75    vision_model: MLlamaVisionModel,
76    language_model: MLlamaTextModel,
77    multi_modal_projector: Linear,
78    hidden_size: usize,
79    dtype: DType,
80}
81
82impl MLlamaModel {
83    pub(crate) fn new(
84        cfg: &MLlamaConfig,
85        vb: ShardedVarBuilder,
86        is_gptx: bool,
87        normal_loading_metadata: NormalLoadingMetadata,
88        attention_mechanism: AttentionImplementation,
89    ) -> Result<Self> {
90        let real_dev = normal_loading_metadata.real_device.clone();
91        Ok(Self {
92            vision_model: MLlamaVisionModel::new(
93                &cfg.vision_config,
94                vb.pp("vision_model"),
95                &real_dev,
96                &normal_loading_metadata.mapper.get_comm_for(0)?,
97            )?,
98            language_model: MLlamaTextModel::new(
99                &cfg.text_config,
100                vb.pp("language_model"),
101                is_gptx,
102                normal_loading_metadata,
103                attention_mechanism,
104            )?,
105            multi_modal_projector: linear(
106                cfg.vision_config.vision_output_dim,
107                cfg.text_config.hidden_size,
108                vb.pp("multi_modal_projector").set_device(real_dev.clone()),
109            )?,
110            hidden_size: cfg.text_config.hidden_size,
111            dtype: vb.dtype(),
112        })
113    }
114
115    #[allow(clippy::too_many_arguments)]
116    fn forward_inner(
117        &self,
118        input_ids: &Tensor,
119        pixel_values: Option<&Tensor>,
120        aspect_ratio_mask: Option<&Tensor>,
121        aspect_ratio_ids: Option<&Tensor>,
122        cross_attn_mask: Option<&Tensor>,
123        seqlen_offsets: &[usize],
124        context_lens: Vec<(usize, usize)>,
125    ) -> Result<Tensor> {
126        let cross_attn_states = if let Some(pixel_values) = pixel_values {
127            let Some(aspect_ratio_mask) = aspect_ratio_mask else {
128                candle_core::bail!("`aspect_ratio_mask` must be specified if `pixel_values` is.");
129            };
130            let Some(aspect_ratio_ids) = aspect_ratio_ids else {
131                candle_core::bail!("`aspect_ratio_ids` must be specified if `pixel_values` is.");
132            };
133            let vision_outputs =
134                self.vision_model
135                    .forward(pixel_values, aspect_ratio_ids, aspect_ratio_mask)?;
136            let cross_attention_states = self
137                .multi_modal_projector
138                .forward(&vision_outputs.flatten(0, 1)?)?
139                .reshape(((), vision_outputs.dim(D::Minus2)?, self.hidden_size))?
140                .to_dtype(self.dtype)?;
141            Some(cross_attention_states)
142        } else {
143            None
144        };
145
146        let (cross_attn_mask, full_text_row_masked_out_mask) =
147            if let Some(cross_attn_mask) = cross_attn_mask {
148                let (mut cmask, fmask) = prepare_cross_attention_mask(
149                    cross_attn_mask,
150                    self.vision_model.num_patches,
151                    self.dtype,
152                )?;
153                cmask = cmask.squeeze(1)?;
154                (Some(cmask), Some(fmask))
155            } else {
156                (None, None)
157            };
158
159        self.language_model.forward(
160            input_ids,
161            cross_attn_states.as_ref(),
162            cross_attn_mask.as_ref(),
163            full_text_row_masked_out_mask.as_ref(),
164            seqlen_offsets,
165            context_lens,
166        )
167    }
168}
169
170#[derive(Default)]
171pub(crate) struct MLlamaSpecificArgs {
172    pub aspect_ratio_ids: Option<Tensor>,
173    pub aspect_ratio_mask: Option<Tensor>,
174    pub cross_attn_mask: Option<Tensor>,
175}
176
177impl VisionModel for MLlamaModel {
178    fn cache(&self) -> &EitherCache {
179        &self.language_model.cache
180    }
181    fn cache_mut(&mut self) -> &mut EitherCache {
182        &mut self.language_model.cache
183    }
184    fn config(&self) -> &ModelConfigMetadata {
185        &self.language_model.cfg
186    }
187    fn device(&self) -> &Device {
188        &self.language_model.device
189    }
190    fn has_conv2d(&self) -> bool {
191        true
192    }
193    fn max_seq_len(&self) -> usize {
194        self.language_model.max_position_embeddings
195    }
196    fn forward(
197        &self,
198        input_ids: &Tensor,
199        pixel_values: Option<Tensor>,
200        seqlen_offsets: &[usize],
201        context_lens: Vec<(usize, usize)>,
202        _position_ids: Vec<usize>,
203        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
204        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
205        _flash_params: &FlashParams,
206    ) -> Result<Tensor> {
207        let MLlamaSpecificArgs {
208            aspect_ratio_ids,
209            aspect_ratio_mask,
210            cross_attn_mask,
211        } = *model_specific_args
212            .downcast()
213            .expect("Cannot downcast into `MLlamaSpecificArgs`");
214        self.forward_inner(
215            input_ids,
216            pixel_values.as_ref(),
217            aspect_ratio_mask.as_ref(),
218            aspect_ratio_ids.as_ref(),
219            cross_attn_mask.as_ref(),
220            seqlen_offsets,
221            context_lens,
222        )
223    }
224    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
225        Box::new(MLlamaSpecificArgs::default())
226    }
227}
228
229impl IsqModel for MLlamaModel {
230    fn get_layers(
231        &mut self,
232    ) -> (
233        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
234        &dyn DeviceMapper,
235    ) {
236        let (mut layers, mapper) = self.language_model.get_layers();
237        layers.extend(
238            self.vision_model
239                .get_isq_layers()
240                .into_iter()
241                .map(|layer| (layer, None)),
242        );
243        (layers, mapper)
244    }
245
246    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
247        let uvb = UnVarBuilder::new();
248
249        uvb.pp("multi_modal_projector")
250            .add(&self.multi_modal_projector);
251        uvb.pp("language_model")
252            .extend(self.language_model.residual_tensors());
253        uvb.pp("vision_model")
254            .extend(self.vision_model.residual_tensors());
255
256        uvb.to_safetensors()
257    }
258
259    // NOTE: We ONLY calibrate the text bits of these models, so we should only track/return those parts!!
260
261    /// This is used for imatrix generation internally. Begin stats tracking.
262    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
263        let layers = self
264            .language_model
265            .get_layers()
266            .0
267            .into_iter()
268            .map(|(layer, _)| layer)
269            .collect::<Vec<_>>();
270        for layer in layers {
271            Arc::get_mut(layer).unwrap().begin_track_stats()?;
272        }
273        Ok(())
274    }
275
276    /// End stats tracking and return the imatrix data
277    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
278        let layers = self
279            .language_model
280            .get_layers()
281            .0
282            .into_iter()
283            .enumerate()
284            .map(|(i, (layer, _))| (i, layer))
285            .collect::<Vec<_>>();
286        let mut data = HashMap::new();
287        for (i, layer) in layers {
288            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
289        }
290        Ok(CollectedImatrixData(data))
291    }
292}
293
294impl AnyMoeBaseModelMixin for MLlamaModel {}