#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

use crate::{
    amoe::{AnyMoeBaseModelMixin, MlpLayer},
    device_map::DeviceMapper,
    layers::{self, Activation, RmsNorm},
    models,
    ops::SplitOp,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
    },
    utils::unvarbuilder::UnVarBuilder,
    AnyMoeConfig, AnyMoeExpertType,
};
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
pub use config::Mistral3Config;
pub use inputs_processor::Mistral3Processor;
use mistralrs_quant::{NonZeroOp, QuantMethod, ShardedVarBuilder};
use models::mistral::Model as Mistral;
use vision::Mistral3VisionModel;

mod config;
mod inputs_processor;
mod vision;

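/// Spatial patch merger for the vision features, applied before projection into the
/// text embedding space.
///
/// Each image's flat patch sequence is reassembled into its patch grid and collapsed in
/// non-overlapping `spatial_merge_size x spatial_merge_size` blocks, concatenating each
/// block's features along the channel dimension; `merging_layer` then maps the widened
/// features back to the vision `hidden_size`. As an illustrative sketch (numbers chosen
/// here, not taken from any config): with `spatial_merge_size = 2` and a 16x16 patch grid
/// of width-`d` features, 256 tokens of width `d` become 64 tokens of width `4 * d`,
/// which `merging_layer` projects back to width `d`.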
struct Mistral3PatchMerger {
    merging_layer: Linear,
    spatial_merge_size: usize,
    patch_size: usize,
}

impl Mistral3PatchMerger {
    fn new(cfg: &Mistral3Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            merging_layer: layers::linear_no_bias(
                cfg.vision_config.hidden_size * cfg.spatial_merge_size.pow(2),
                cfg.vision_config.hidden_size,
                vb.pp("merging_layer"),
            )?,
            spatial_merge_size: cfg.spatial_merge_size,
            patch_size: cfg.vision_config.patch_size,
        })
    }

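    /// `image_sizes` are the per-image (height, width) pixel sizes; they are divided by
    /// `patch_size` to recover each image's patch-grid shape so its flat token sequence
    /// can be merged spatially.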
    fn forward(&self, image_features: &Tensor, image_sizes: Vec<(u32, u32)>) -> Result<Tensor> {
        let image_sizes = image_sizes
            .iter()
            .map(|&(h, w)| (h as usize / self.patch_size, w as usize / self.patch_size))
            .collect::<Vec<_>>();

        let tokens_per_image = image_sizes.iter().map(|&(h, w)| h * w).collect::<Vec<_>>();
        let d = image_features.dim(D::Minus1)?;

        let mut permuted_tensor = Vec::new();

        for (image_index, image_tokens) in image_features
            .split(&tokens_per_image, 0)?
            .iter()
            .enumerate()
        {
            let (h, w) = image_sizes[image_index];
            let image_grid = image_tokens
                .reshape((h, w, d))?
                .permute((2, 0, 1))?
                .unsqueeze(0)?;
            let grid = {
                // Equivalent of torch.nn.functional.unfold with kernel size and stride
                // both equal to `spatial_merge_size`: extract non-overlapping
                // `spatial_merge_size x spatial_merge_size` patches.
                let patches = image_grid
                    .unfold(2, self.spatial_merge_size, self.spatial_merge_size)?
                    .unfold(3, self.spatial_merge_size, self.spatial_merge_size)?;
                // Bring the intra-patch dims next to the channel dim, then flatten each
                // patch into a single column.
                let patches = patches.permute((0, 1, 4, 5, 2, 3))?;
                patches.contiguous()?.reshape((
                    1,
                    d * self.spatial_merge_size * self.spatial_merge_size,
                    (),
                ))?
            };
            // One row per merged patch, of width d * spatial_merge_size^2.
            let grid = grid
                .reshape((d * self.spatial_merge_size.pow(2), ()))?
                .t()?;
            permuted_tensor.push(grid);
        }

        let image_features = Tensor::cat(&permuted_tensor, 0)?;

        self.merging_layer.forward(&image_features)
    }
}

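/// Projects merged vision features into the text model's embedding space:
/// RMS norm -> patch merger -> `linear_1` -> activation -> `linear_2`.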
struct Mistral3MultiModalProjector {
    norm: RmsNorm,
    linear_1: Linear,
    linear_2: Linear,
    act: Activation,
    patch_merger: Mistral3PatchMerger,
}

impl Mistral3MultiModalProjector {
    fn new(cfg: &Mistral3Config, vb: ShardedVarBuilder) -> Result<Self> {
        let norm = RmsNorm::new(
            cfg.vision_config.hidden_size,
            cfg.text_config.rms_norm_eps,
            vb.pp("norm"),
        )?;
        // Only a single vision feature layer (the final hidden state) is supported; see
        // the `vision_feature_layer == -1` assertion in `Mistral3Model::new`.
        let num_feature_layers = 1;
        let linear_1 = layers::linear_b(
            cfg.vision_config.hidden_size * num_feature_layers,
            cfg.text_config.hidden_size,
            cfg.multimodal_projector_bias,
            vb.pp("linear_1"),
        )?;
        let linear_2 = layers::linear_b(
            cfg.text_config.hidden_size,
            cfg.text_config.hidden_size,
            cfg.multimodal_projector_bias,
            vb.pp("linear_2"),
        )?;
        let patch_merger = Mistral3PatchMerger::new(cfg, vb.pp("patch_merger"))?;
        Ok(Self {
            norm,
            linear_1,
            linear_2,
            act: cfg.projector_hidden_act,
            patch_merger,
        })
    }

    fn forward(&self, image_features: &Tensor, image_sizes: Vec<(u32, u32)>) -> Result<Tensor> {
        let mut hidden_states = self.norm.forward(image_features)?;
        hidden_states = self.patch_merger.forward(&hidden_states, image_sizes)?;
        hidden_states = self.linear_1.forward(&hidden_states)?.apply(&self.act)?;
        self.linear_2.forward(&hidden_states)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("norm").add(&self.norm);
        uvb.pp("linear_1").add(&self.linear_1);
        uvb.pp("linear_2").add(&self.linear_2);
        uvb.pp("patch_merger")
            .pp("merging_layer")
            .add(&self.patch_merger.merging_layer);

        uvb.to_safetensors()
    }
}

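/// Mistral 3 vision-language model: a vision tower (`vision_tower`), a multi-modal
/// projector (`multi_modal_projector`), and a Mistral text model (`language_model`).
/// The names in parentheses are the weight prefixes used when loading the checkpoint.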
pub struct Mistral3Model {
    text_model: Mistral,
    vision_model: Mistral3VisionModel,
    mmproj: Mistral3MultiModalProjector,
    cfg: Mistral3Config,
}

impl Mistral3Model {
    pub fn new(
        cfg: &Mistral3Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vision_model = Mistral3VisionModel::new(
            &cfg.vision_config,
            vb.pp("vision_tower"),
            &normal_loading_metadata,
        )?;
        let mmproj = Mistral3MultiModalProjector::new(
            cfg,
            vb.pp("multi_modal_projector")
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        let text_model = Mistral::new(
            &cfg.text_config,
            vb.pp("language_model"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;

        // Only the final vision hidden state is supported as the feature layer.
        assert_eq!(cfg.vision_feature_layer, -1);

        Ok(Self {
            vision_model,
            text_model,
            mmproj,
            cfg: cfg.clone(),
        })
    }

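    /// Run the vision tower on the preprocessed pixel values and project the resulting
    /// features into the text embedding space via the multi-modal projector.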
    fn get_image_features(
        &self,
        image_features: &Tensor,
        image_sizes: Vec<(u32, u32)>,
    ) -> Result<Tensor> {
        let image_outputs = self
            .vision_model
            .forward(image_features, image_sizes.clone())?;
        let selected_image_feature = image_outputs;
        self.mmproj
            .forward(&selected_image_feature.squeeze(0)?, image_sizes)
    }

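    /// Embed `input_ids`, splice the projected image features into the positions holding
    /// `image_token_index`, and run the combined embeddings through the text model.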
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(u32, u32)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut input_embeds = self.text_model.get_input_embeddings(input_ids)?;

        if let Some(pixel_values) = pixel_values {
            // Flattened indices of the embedding elements that belong to image tokens.
            let special_image_mask = input_ids
                .eq(self.cfg.image_token_index as f64)?
                .unsqueeze(D::Minus1)?
                .broadcast_as(input_embeds.shape().clone())?
                .to_dtype(DType::U32)?;
            let mask_flat = special_image_mask.flatten_all()?;
            let indices = mask_flat.nonzero()?.squeeze(1)?;

            let image_sizes = image_sizes.unwrap();
            let image_features = self.get_image_features(
                &pixel_values.to_dtype(self.vision_model.dtype())?,
                image_sizes,
            )?;

            // Overwrite the embeddings at the image-token positions with the projected
            // image features: scatter-add the difference so that those positions end up
            // equal to the image features while all other positions are untouched.
            let mut x_flat = input_embeds.flatten_all()?;
            let src_flat = image_features.flatten_all()?;

            let current_vals = x_flat.gather(&indices, 0)?;
            let diff = (src_flat - current_vals)?;
            x_flat = x_flat.scatter_add(&indices, &diff, 0)?;

            input_embeds = x_flat.reshape(input_embeds.shape())?;
        }

        self.text_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}

impl IsqModel for Mistral3Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let (mut tensors, mapper) = self.text_model.get_layers();
        tensors.extend(self.vision_model.get_layers());
        (tensors, mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();
        uvb.pp("multi_modal_projector")
            .extend(self.mmproj.residual_tensors());
        uvb.pp("language_model")
            .extend(self.text_model.residual_tensors());

        uvb.to_safetensors()
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        self.text_model.imatrix_names()
    }
}

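/// Extra per-request arguments for [`Mistral3Model::forward`]: the original
/// (height, width) pixel sizes of each image, needed to rebuild per-image patch grids.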
#[derive(Default)]
pub struct Mistral3SpecificArgs {
    pub image_sizes: Option<Vec<(u32, u32)>>,
}

impl VisionModel for Mistral3Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn std::any::Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        let Mistral3SpecificArgs { image_sizes } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Mistral3SpecificArgs`");
        self.forward(
            input_ids,
            pixel_values,
            seqlen_offsets,
            context_lens,
            image_sizes,
            metadata,
            flash_params,
        )
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
        Box::new(Mistral3SpecificArgs::default())
    }
    fn cache(&self) -> &EitherCache {
        self.text_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.text_model.cache_mut()
    }
    fn device(&self) -> &Device {
        self.text_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.text_model.max_seq_len()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.text_model.config()
    }
    fn has_conv2d(&self) -> bool {
        true
    }
}

impl AnyMoeBaseModelMixin for Mistral3Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        self.text_model.get_mlps()
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        self.text_model.get_mlps_mut()
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        self.text_model.create_anymoe_layers(
            additional_vbs,
            config,
            (prefix, mlp),
            layers,
            expert_type,
            gate_vb,
        )
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}