mistralrs_core/vision_models/mllama/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

mod config;
mod inputs_processor;
mod text;
mod vision;

use std::{any::Any, collections::HashMap, sync::Arc};

pub(crate) use config::{MLlamaConfig, MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig};
use config::{MLlamaVisionConfig, VisionActivation};
pub(crate) use inputs_processor::MLlamaProcessor;
use text::MLlamaTextModel;
use vision::MLlamaVisionModel;

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
use mistralrs_quant::{CollectedImatrixData, QuantMethod, ShardedVarBuilder};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::{linear, GetFloatInfo},
    layers_masker::masked_fill,
    ops::RepeatInterleaveOp,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
    },
    utils::unvarbuilder::UnVarBuilder,
};

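/// Expand a per-tile cross-attention mask into a per-vision-token attention mask.
///
/// The mask is repeated `num_vision_tokens` times along its last dimension, masked
/// positions are filled with the minimum value of `dtype` (acting as `-inf`), and a
/// second mask is returned marking which text rows attend to at least one vision token.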
fn prepare_cross_attention_mask(
    cross_attention_mask: &Tensor,
    num_vision_tokens: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let bs = cross_attention_mask.dim(0)?;
    let text_total_length = cross_attention_mask.dim(1)?;
    // Expand from one entry per image tile to one entry per vision token.
    let mut cross_attn_mask = cross_attention_mask
        .to_dtype(DType::F32)?
        .repeat_interleave(num_vision_tokens, 3)?;
    cross_attn_mask = cross_attn_mask.reshape((bs, text_total_length, ()))?;
    cross_attn_mask = cross_attn_mask.unsqueeze(1)?;

    // Invert the mask (1 = masked out) and fill masked positions with the dtype minimum,
    // which behaves as negative infinity in the attention softmax.
    let inverted_cross_attn_mask = (1. - cross_attn_mask)?;
    let neg_inf_value = dtype.finfo()?.min;
    cross_attn_mask = masked_fill(
        &inverted_cross_attn_mask,
        &inverted_cross_attn_mask.ne(0.)?,
        neg_inf_value as f32,
    )?;

    // Mark text rows that attend to at least one vision token; rows which attend to
    // none are zeroed out entirely below.
    let full_text_row_masked_out_mask = cross_attn_mask
        .ne(neg_inf_value)?
        .sum(D::Minus1)?
        .ne(0.)?
        .unsqueeze(D::Minus1)?;

    cross_attn_mask = cross_attn_mask
        .broadcast_mul(&full_text_row_masked_out_mask.to_dtype(cross_attn_mask.dtype())?)?
        .to_dtype(DType::F32)?
        .to_dtype(dtype)?;

    Ok((cross_attn_mask, full_text_row_masked_out_mask))
}

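/// Multimodal Llama (mllama): a vision model whose outputs are projected into the text
/// model's hidden size and consumed via cross-attention in the language model.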
pub(crate) struct MLlamaModel {
    vision_model: MLlamaVisionModel,
    language_model: MLlamaTextModel,
    multi_modal_projector: Linear,
    hidden_size: usize,
    dtype: DType,
}

impl MLlamaModel {
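    /// Load the vision model, language model, and multi-modal projector from `vb`,
    /// placing the vision tower and projector on `normal_loading_metadata.real_device`.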
    pub(crate) fn new(
        cfg: &MLlamaConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let real_dev = normal_loading_metadata.real_device.clone();
        Ok(Self {
            vision_model: MLlamaVisionModel::new(
                &cfg.vision_config,
                vb.pp("vision_model"),
                &real_dev,
                &normal_loading_metadata.mapper.get_comm_for(0)?,
            )?,
            language_model: MLlamaTextModel::new(
                &cfg.text_config,
                vb.pp("language_model"),
                is_gptx,
                normal_loading_metadata,
                attention_mechanism,
            )?,
            multi_modal_projector: linear(
                cfg.vision_config.vision_output_dim,
                cfg.text_config.hidden_size,
                vb.pp("multi_modal_projector").set_device(real_dev.clone()),
            )?,
            hidden_size: cfg.text_config.hidden_size,
            dtype: vb.dtype(),
        })
    }

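    /// If `pixel_values` are provided, run the vision model, project its outputs to the
    /// text hidden size, and prepare the cross-attention masks; then run the language model.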
    #[allow(clippy::too_many_arguments)]
    fn forward_inner(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<&Tensor>,
        aspect_ratio_mask: Option<&Tensor>,
        aspect_ratio_ids: Option<&Tensor>,
        cross_attn_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
    ) -> Result<Tensor> {
        let cross_attn_states = if let Some(pixel_values) = pixel_values {
            let Some(aspect_ratio_mask) = aspect_ratio_mask else {
                candle_core::bail!("`aspect_ratio_mask` must be specified if `pixel_values` is.");
            };
            let Some(aspect_ratio_ids) = aspect_ratio_ids else {
                candle_core::bail!("`aspect_ratio_ids` must be specified if `pixel_values` is.");
            };
            let vision_outputs =
                self.vision_model
                    .forward(pixel_values, aspect_ratio_ids, aspect_ratio_mask)?;
            let cross_attention_states = self
                .multi_modal_projector
                .forward(&vision_outputs.flatten(0, 1)?)?
                .reshape(((), vision_outputs.dim(D::Minus2)?, self.hidden_size))?
                .to_dtype(self.dtype)?;
            Some(cross_attention_states)
        } else {
            None
        };

        let (cross_attn_mask, full_text_row_masked_out_mask) =
            if let Some(cross_attn_mask) = cross_attn_mask {
                let (mut cmask, fmask) = prepare_cross_attention_mask(
                    cross_attn_mask,
                    self.vision_model.num_patches,
                    self.dtype,
                )?;
                cmask = cmask.squeeze(1)?;
                (Some(cmask), Some(fmask))
            } else {
                (None, None)
            };

        self.language_model.forward(
            input_ids,
            cross_attn_states.as_ref(),
            cross_attn_mask.as_ref(),
            full_text_row_masked_out_mask.as_ref(),
            seqlen_offsets,
            context_lens,
        )
    }
}

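/// Extra per-request inputs for mllama, passed through `VisionModel::forward` as a
/// `Box<dyn Any>`. All fields default to `None` (text-only requests).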
#[derive(Default)]
pub(crate) struct MLlamaSpecificArgs {
    pub aspect_ratio_ids: Option<Tensor>,
    pub aspect_ratio_mask: Option<Tensor>,
    pub cross_attn_mask: Option<Tensor>,
}

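// `VisionModel` glue: cache, config, and device accessors delegate to the language model,
// while `forward` downcasts the model-specific args and calls `forward_inner`.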
impl VisionModel for MLlamaModel {
    fn cache(&self) -> &EitherCache {
        &self.language_model.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.language_model.cache
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.language_model.cfg
    }
    fn device(&self) -> &Device {
        &self.language_model.device
    }
    fn has_conv2d(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_position_embeddings
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let MLlamaSpecificArgs {
            aspect_ratio_ids,
            aspect_ratio_mask,
            cross_attn_mask,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `MLlamaSpecificArgs`");
        self.forward_inner(
            input_ids,
            pixel_values.as_ref(),
            aspect_ratio_mask.as_ref(),
            aspect_ratio_ids.as_ref(),
            cross_attn_mask.as_ref(),
            seqlen_offsets,
            context_lens,
        )
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(MLlamaSpecificArgs::default())
    }
}

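// ISQ support: quantizable layers are gathered from both the language and vision models,
// while imatrix statistics are tracked and collected from the language model layers only.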
impl IsqModel for MLlamaModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let (mut layers, mapper) = self.language_model.get_layers();
        layers.extend(
            self.vision_model
                .get_isq_layers()
                .into_iter()
                .map(|layer| (layer, None)),
        );
        (layers, mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("multi_modal_projector")
            .add(&self.multi_modal_projector);
        uvb.pp("language_model")
            .extend(self.language_model.residual_tensors());
        uvb.pp("vision_model")
            .extend(self.vision_model.residual_tensors());

        uvb.to_safetensors()
    }

    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
        let layers = self
            .language_model
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .language_model
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }
}

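// AnyMoE: only the trait's default behavior is provided for this model.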
impl AnyMoeBaseModelMixin for MLlamaModel {}