1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5use crate::{
6 amoe::{AnyMoeBaseModelMixin, MlpLayer},
7 device_map::DeviceMapper,
8 layers::{self, Activation, RmsNorm},
9 models,
10 ops::{NonZeroOp, SplitOp},
11 paged_attention::{AttentionImplementation, ModelConfigMetadata},
12 pipeline::{
13 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
14 EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
15 },
16 utils::unvarbuilder::UnVarBuilder,
17 AnyMoeConfig, AnyMoeExpertType,
18};
19use candle_core::{DType, Device, Result, Tensor, D};
20use candle_nn::{Linear, Module};
21pub use config::Mistral3Config;
22pub use inputs_processor::Mistral3Processor;
23use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
24use models::mistral::Model as Mistral;
25use vision::Mistral3VisionModel;
26
27mod config;
28mod inputs_processor;
29mod vision;
30
/// Merges each `spatial_merge_size x spatial_merge_size` block of spatially
/// adjacent vision-tower patch embeddings into a single token via a linear
/// projection.
struct Mistral3PatchMerger {
    // Projects the `hidden_size * spatial_merge_size^2` concatenation of one
    // merged block back down to `hidden_size` (no bias).
    merging_layer: Linear,
    // Side length (in patches) of the square block merged into one token.
    spatial_merge_size: usize,
    // Vision-tower patch size in pixels; used to convert image sizes (px)
    // into patch-grid dimensions.
    patch_size: usize,
}
36
impl Mistral3PatchMerger {
    /// Load the merger's projection weights from `vb.pp("merging_layer")`.
    ///
    /// The projection maps `hidden_size * spatial_merge_size^2` (one merged
    /// block of patch embeddings, concatenated) down to `hidden_size`.
    fn new(cfg: &Mistral3Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            merging_layer: layers::linear_no_bias(
                cfg.vision_config.hidden_size * cfg.spatial_merge_size.pow(2),
                cfg.vision_config.hidden_size,
                vb.pp("merging_layer"),
            )?,
            spatial_merge_size: cfg.spatial_merge_size,
            patch_size: cfg.vision_config.patch_size,
        })
    }

    /// Merge patch tokens image-by-image.
    ///
    /// `image_features` is all images' patch tokens concatenated along dim 0
    /// (`(total_tokens, d)`); `image_sizes` gives each image's
    /// `(height, width)` in pixels. For each image the flat token sequence is
    /// reshaped back to its 2D patch grid, non-overlapping
    /// `spatial_merge_size x spatial_merge_size` blocks are flattened into
    /// single rows of width `d * spatial_merge_size^2`, and the rows from all
    /// images are concatenated and projected back to width `d`.
    fn forward(&self, image_features: &Tensor, image_sizes: Vec<(u32, u32)>) -> Result<Tensor> {
        // Pixel sizes -> patch-grid sizes (integer division by patch_size).
        let image_sizes = image_sizes
            .iter()
            .map(|&(h, w)| (h as usize / self.patch_size, w as usize / self.patch_size))
            .collect::<Vec<_>>();

        // Number of patch tokens per image; used as split sizes along dim 0.
        let tokens_per_image = image_sizes.iter().map(|&(h, w)| h * w).collect::<Vec<_>>();
        let d = image_features.dim(D::Minus1)?;

        let mut permuted_tensor = Vec::new();

        for (image_index, image_tokens) in image_features
            .split(&tokens_per_image, 0)?
            .iter()
            .enumerate()
        {
            let (h, w) = image_sizes[image_index];
            // (h*w, d) -> (d, h, w) with a leading singleton batch dim, so the
            // unfolds below can slide over the spatial axes.
            let image_grid = image_tokens
                .reshape((h, w, d))?
                .permute((2, 0, 1))?
                .unsqueeze(0)?;
            let grid = {
                // Two unfolds with stride == window size extract
                // non-overlapping merge blocks over H then W.
                let patches = image_grid
                    .unfold(2, self.spatial_merge_size, self.spatial_merge_size)?
                    .unfold(3, self.spatial_merge_size, self.spatial_merge_size)?;
                // Bring the intra-block offsets next to the channel dim so the
                // reshape groups (d, merge, merge) per block.
                let patches = patches.permute((0, 1, 4, 5, 2, 3))?;
                patches.contiguous()?.reshape((
                    1,
                    d * self.spatial_merge_size * self.spatial_merge_size,
                    (),
                ))?
            };
            // -> (num_blocks, d * spatial_merge_size^2): one row per merged token.
            let grid = grid
                .reshape((d * self.spatial_merge_size.pow(2), ()))?
                .t()?;
            permuted_tensor.push(grid);
        }

        let image_features = Tensor::cat(&permuted_tensor, 0)?;

        // Project each merged block back to width d.
        self.merging_layer.forward(&image_features)
    }
}
99
/// Projects vision-tower outputs into the text model's embedding space:
/// RMS norm -> patch merging -> two-layer MLP.
struct Mistral3MultiModalProjector {
    // RMS norm over the raw vision features (vision hidden size wide).
    norm: RmsNorm,
    // First projection: vision hidden size -> text hidden size.
    linear_1: Linear,
    // Second projection: text hidden size -> text hidden size.
    linear_2: Linear,
    // Activation applied between the two projections.
    act: Activation,
    // Merges spatial blocks of patch tokens before projection.
    patch_merger: Mistral3PatchMerger,
}
107
108impl Mistral3MultiModalProjector {
109 fn new(cfg: &Mistral3Config, vb: ShardedVarBuilder) -> Result<Self> {
110 let norm = RmsNorm::new(
111 cfg.vision_config.hidden_size,
112 cfg.text_config.rms_norm_eps,
113 vb.pp("norm"),
114 )?;
115 let num_feature_layers = 1;
120 let linear_1 = layers::linear_b(
121 cfg.vision_config.hidden_size * num_feature_layers,
122 cfg.text_config.hidden_size,
123 cfg.multimodal_projector_bias,
124 vb.pp("linear_1"),
125 )?;
126 let linear_2 = layers::linear_b(
127 cfg.text_config.hidden_size,
128 cfg.text_config.hidden_size,
129 cfg.multimodal_projector_bias,
130 vb.pp("linear_2"),
131 )?;
132 let patch_merger = Mistral3PatchMerger::new(cfg, vb.pp("patch_merger"))?;
133 Ok(Self {
134 norm,
135 linear_1,
136 linear_2,
137 act: cfg.projector_hidden_act,
138 patch_merger,
139 })
140 }
141
142 fn forward(&self, image_features: &Tensor, image_sizes: Vec<(u32, u32)>) -> Result<Tensor> {
143 let mut hidden_states = self.norm.forward(image_features)?;
144 hidden_states = self.patch_merger.forward(&hidden_states, image_sizes)?;
145 hidden_states = self.linear_1.forward(&hidden_states)?.apply(&self.act)?;
146 self.linear_2.forward(&hidden_states)
147 }
148
149 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
150 let uvb = UnVarBuilder::new();
151
152 uvb.pp("norm").add(&self.norm);
153 uvb.pp("linear_1").add(&self.linear_1);
154 uvb.pp("linear_2").add(&self.linear_2);
155 uvb.pp("patch_merger")
156 .pp("merging_layer")
157 .add(&self.patch_merger.merging_layer);
158
159 uvb.to_safetensors()
160 }
161}
162
/// Mistral 3 vision-language model: a vision tower, a multi-modal projector,
/// and a Mistral text backbone.
pub struct Mistral3Model {
    // Language model that consumes the (possibly image-augmented) embeddings.
    text_model: Mistral,
    // Vision tower producing patch features from pixel values.
    vision_model: Mistral3VisionModel,
    // Projects vision features into the text embedding space.
    mmproj: Mistral3MultiModalProjector,
    // Retained config; `image_token_index` is read at forward time.
    cfg: Mistral3Config,
}
169
170impl Mistral3Model {
171 pub fn new(
172 cfg: &Mistral3Config,
173 vb: ShardedVarBuilder,
174 is_gptx: bool,
175 normal_loading_metadata: NormalLoadingMetadata,
176 attention_mechanism: AttentionImplementation,
177 ) -> Result<Self> {
178 let vision_model = Mistral3VisionModel::new(
179 &cfg.vision_config,
180 vb.pp("vision_tower"),
181 &normal_loading_metadata,
182 )?;
183 let mmproj = Mistral3MultiModalProjector::new(
184 cfg,
185 vb.pp("multi_modal_projector")
186 .set_device(normal_loading_metadata.real_device.clone()),
187 )?;
188 let text_model = Mistral::new(
189 &cfg.text_config,
190 vb.pp("language_model"),
191 is_gptx,
192 normal_loading_metadata,
193 attention_mechanism,
194 )?;
195
196 assert_eq!(cfg.vision_feature_layer, -1);
198
199 Ok(Self {
200 vision_model,
201 text_model,
202 mmproj,
203 cfg: cfg.clone(),
204 })
205 }
206
207 fn get_image_features(
208 &self,
209 image_features: &Tensor,
210 image_sizes: Vec<(u32, u32)>,
211 ) -> Result<Tensor> {
212 let image_outputs = self
213 .vision_model
214 .forward(image_features, image_sizes.clone())?;
215 let selected_image_feature = image_outputs;
216 self.mmproj
217 .forward(&selected_image_feature.squeeze(0)?, image_sizes)
218 }
219
220 #[allow(clippy::too_many_arguments)]
221 pub fn forward(
222 &self,
223 input_ids: &Tensor,
224 pixel_values: Option<Tensor>,
225 seqlen_offsets: &[usize],
226 context_lens: Vec<(usize, usize)>,
227 image_sizes: Option<Vec<(u32, u32)>>,
228 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
229 flash_params: &FlashParams,
230 ) -> Result<Tensor> {
231 let mut input_embeds = self.text_model.get_input_embeddings(input_ids)?;
232
233 if let Some(pixel_values) = pixel_values {
234 let image_sizes = image_sizes.unwrap();
235 let image_features = self.get_image_features(
236 &pixel_values.to_dtype(self.vision_model.dtype())?,
237 image_sizes,
238 )?;
239
240 let special_image_mask = input_ids
241 .eq(self.cfg.image_token_index as f64)?
242 .unsqueeze(D::Minus1)?
243 .broadcast_as(input_embeds.shape().clone())?
244 .to_dtype(DType::U32)?;
245
246 let mask_flat = special_image_mask.flatten_all()?;
247 let mut x_flat = input_embeds.flatten_all()?;
248 let src_flat = image_features.flatten_all()?;
249
250 let indices = mask_flat.nonzero()?.squeeze(1)?;
251 let current_vals = x_flat.gather(&indices, 0)?;
252 let diff = (src_flat - current_vals)?;
253 x_flat = x_flat.scatter_add(&indices, &diff, 0)?;
254
255 input_embeds = x_flat.reshape(input_embeds.shape())?;
256 }
257
258 self.text_model.forward_embeds(
259 input_ids,
260 input_embeds,
261 seqlen_offsets,
262 context_lens,
263 metadata,
264 flash_params,
265 )
266 }
267}
268
269impl IsqModel for Mistral3Model {
270 fn get_layers(
271 &mut self,
272 ) -> (
273 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
274 &dyn DeviceMapper,
275 ) {
276 let (mut tensors, mapper) = self.text_model.get_layers();
277 tensors.extend(self.vision_model.get_layers());
278 (tensors, mapper)
279 }
280
281 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
282 let uvb = UnVarBuilder::new();
283 uvb.pp("multi_modal_projector")
284 .extend(self.mmproj.residual_tensors());
285 uvb.pp("language_model")
286 .extend(self.text_model.residual_tensors());
287
288 uvb.to_safetensors()
289 }
290
291 fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
292 self.text_model.imatrix_names()
293 }
294}
295
/// Model-specific arguments threaded through the vision pipeline.
#[derive(Default)]
pub struct Mistral3SpecificArgs {
    // Per-image `(height, width)` in pixels; required whenever
    // `pixel_values` is passed to `forward`.
    pub image_sizes: Option<Vec<(u32, u32)>>,
}
300
301impl VisionModel for Mistral3Model {
302 fn forward(
303 &self,
304 input_ids: &Tensor,
305 pixel_values: Option<Tensor>,
306 seqlen_offsets: &[usize],
307 context_lens: Vec<(usize, usize)>,
308 _position_ids: Vec<usize>,
309 model_specific_args: Box<dyn std::any::Any>,
310 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
311 flash_params: &FlashParams,
312 ) -> candle_core::Result<Tensor> {
313 let Mistral3SpecificArgs { image_sizes } = *model_specific_args
314 .downcast()
315 .expect("Cannot downcast into `Mistral3SpecificArgs`");
316 self.forward(
317 input_ids,
318 pixel_values,
319 seqlen_offsets,
320 context_lens,
321 image_sizes,
322 metadata,
323 flash_params,
324 )
325 }
326 fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
327 Box::new(Mistral3SpecificArgs::default())
328 }
329 fn cache(&self) -> &EitherCache {
330 self.text_model.cache()
331 }
332 fn cache_mut(&mut self) -> &mut EitherCache {
333 self.text_model.cache_mut()
334 }
335 fn device(&self) -> &Device {
336 self.text_model.device()
337 }
338 fn max_seq_len(&self) -> usize {
339 self.text_model.max_seq_len()
340 }
341 fn config(&self) -> &ModelConfigMetadata {
342 self.text_model.config()
343 }
344 fn has_conv2d(&self) -> bool {
345 true
346 }
347}
348
349impl AnyMoeBaseModelMixin for Mistral3Model {
350 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
351 self.text_model.get_mlps()
352 }
353 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
354 self.text_model.get_mlps_mut()
355 }
356 fn create_anymoe_layers(
357 &mut self,
358 additional_vbs: Vec<ShardedVarBuilder>,
359 config: AnyMoeConfig,
360 (prefix, mlp): (String, String),
361 layers: Vec<usize>,
362 expert_type: AnyMoeExpertType,
363 gate_vb: Option<ShardedVarBuilder>,
364 ) -> Result<()> {
365 self.text_model.create_anymoe_layers(
366 additional_vbs,
367 config,
368 (prefix, mlp),
369 layers,
370 expert_type,
371 gate_vb,
372 )
373 }
374 fn amoe_supported(&self) -> bool {
375 true
376 }
377}