1use std::sync::Arc;
2
3use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
4use mistralrs_quant::{linear_b, QuantMethod, ShardedVarBuilder};
5
6use crate::{
7 layers::{self, GetFloatInfo, RmsNorm},
8 pipeline::NormalLoadingMetadata,
9};
10
11fn default_act() -> candle_nn::Activation {
12 candle_nn::Activation::Silu
13}
14
/// Serde fallback for `hidden_size` (transformer width).
fn default_hidden_size() -> usize {
    1_024
}

/// Serde fallback for `intermediate_size` (inner width of the gated MLP).
fn default_intermediate_size() -> usize {
    4_096
}

/// Serde fallback for `num_channels` (input channels of the patch convolution).
fn default_num_channels() -> usize {
    3
}

/// Serde fallback for `num_hidden_layers` (number of transformer layers).
fn default_num_hidden_layers() -> usize {
    24
}

/// Serde fallback for `num_attention_heads`.
fn default_num_attention_heads() -> usize {
    16
}
34
35#[derive(serde::Deserialize, Debug, Clone)]
36pub struct Mistral3VisionConfig {
37 #[serde(default = "default_hidden_size")]
38 pub hidden_size: usize,
39 #[serde(default = "default_num_channels")]
40 pub num_channels: usize,
41 pub image_size: usize,
42 pub patch_size: usize,
43 pub rope_theta: f64,
44 #[serde(default = "default_intermediate_size")]
45 pub intermediate_size: usize,
46 #[serde(default = "default_num_hidden_layers")]
47 pub num_hidden_layers: usize,
48 pub head_dim: Option<usize>,
49 #[serde(default = "default_num_attention_heads")]
50 pub num_attention_heads: usize,
51 #[serde(default = "default_act")]
52 pub hidden_act: candle_nn::Activation,
53}
54
55impl Mistral3VisionConfig {
56 fn head_dim(&self) -> usize {
57 self.head_dim
58 .unwrap_or(self.hidden_size / self.num_attention_heads)
59 }
60}
61
62#[derive(Debug, Clone)]
63struct Attention {
64 q_proj: Arc<dyn QuantMethod>,
65 k_proj: Arc<dyn QuantMethod>,
66 v_proj: Arc<dyn QuantMethod>,
67 o_proj: Arc<dyn QuantMethod>,
68 scale: f64,
69 num_heads: usize,
70 head_dim: usize,
71}
72
73impl Attention {
74 fn new(cfg: &Mistral3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
75 let h = cfg.hidden_size;
76 let num_heads = cfg.num_attention_heads;
77 let head_dim = cfg.head_dim();
78 let q_proj = linear_b(h, h, false, &None, vb.pp("q_proj"))?;
79 let k_proj = linear_b(h, h, false, &None, vb.pp("k_proj"))?;
80 let v_proj = linear_b(h, h, false, &None, vb.pp("v_proj"))?;
81 let o_proj = linear_b(h, h, false, &None, vb.pp("o_proj"))?;
82 let scale = (head_dim as f64).powf(-0.5);
83 Ok(Self {
84 q_proj,
85 k_proj,
86 v_proj,
87 o_proj,
88 scale,
89 num_heads,
90 head_dim,
91 })
92 }
93
94 fn forward(
95 &self,
96 xs: &Tensor,
97 emb: &RotaryEmbedding,
98 subsampled_positions: Option<&Tensor>,
99 attention_mask: Option<&Tensor>,
100 ) -> Result<Tensor> {
101 let (b, patches, _) = xs.dims3()?;
102 let query_states = self.q_proj.forward_autocast(xs)?;
103 let key_states = self.k_proj.forward_autocast(xs)?;
104 let value_states = self.v_proj.forward_autocast(xs)?;
105
106 let shape = (b, patches, self.num_heads, self.head_dim);
107 let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
108 let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
109 let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
110
111 let (query_states, key_states) =
112 emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?;
113 let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
114
115 let attn_weights = match attention_mask {
116 None => attn_weights,
117 Some(mask) => attn_weights.broadcast_add(mask)?,
118 };
119
120 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
121
122 self.o_proj.forward_autocast(
123 &attn_weights
124 .matmul(&value_states)?
125 .transpose(1, 2)?
126 .reshape((b, patches, ()))?,
127 )
128 }
129}
130
131#[derive(Debug, Clone)]
132struct Mlp {
133 gate_proj: Arc<dyn QuantMethod>,
134 up_proj: Arc<dyn QuantMethod>,
135 down_proj: Arc<dyn QuantMethod>,
136 act_fn: candle_nn::Activation,
137}
138
139impl Mlp {
140 fn new(cfg: &Mistral3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
141 let (h, i) = (cfg.hidden_size, cfg.intermediate_size);
142 let gate_proj = linear_b(h, i, false, &None, vb.pp("gate_proj"))?;
143 let up_proj = linear_b(h, i, false, &None, vb.pp("up_proj"))?;
144 let down_proj = linear_b(i, h, false, &None, vb.pp("down_proj"))?;
145 Ok(Self {
146 gate_proj,
147 up_proj,
148 down_proj,
149 act_fn: cfg.hidden_act,
150 })
151 }
152}
153
154impl Module for Mlp {
155 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
156 self.down_proj.forward_autocast(
157 &(self.gate_proj.forward_autocast(xs)?.apply(&self.act_fn)?
158 * self.up_proj.forward_autocast(xs)?)?,
159 )
160 }
161}
162
163#[derive(Debug, Clone)]
164struct AttentionLayer {
165 attention_norm: RmsNorm,
166 feed_forward: Mlp,
167 attention: Attention,
168 ffn_norm: RmsNorm,
169}
170
171impl AttentionLayer {
172 fn new(
173 cfg: &Mistral3VisionConfig,
174 vb: ShardedVarBuilder,
175 normal_loading_metadata: &NormalLoadingMetadata,
176 ) -> Result<Self> {
177 let attention_norm = RmsNorm::new(
178 cfg.hidden_size,
179 1e-5,
180 vb.pp("attention_norm")
181 .set_device(normal_loading_metadata.real_device.clone()),
182 )?;
183 let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?;
184 let attention = Attention::new(cfg, vb.pp("attention"))?;
185 let ffn_norm = RmsNorm::new(
186 cfg.hidden_size,
187 1e-5,
188 vb.pp("ffn_norm")
189 .set_device(normal_loading_metadata.real_device.clone()),
190 )?;
191 Ok(Self {
192 attention_norm,
193 feed_forward,
194 attention,
195 ffn_norm,
196 })
197 }
198
199 fn forward(
200 &self,
201 xs: &Tensor,
202 emb: &RotaryEmbedding,
203 subsampled_positions: Option<&Tensor>,
204 attention_mask: Option<&Tensor>,
205 ) -> Result<Tensor> {
206 let residual = xs;
207 let xs = self.attention.forward(
208 &xs.apply(&self.attention_norm)?,
209 emb,
210 subsampled_positions,
211 attention_mask,
212 )?;
213 let xs = (residual + xs)?;
214 let residual = &xs;
215 let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
216 xs + residual
217 }
218}
219
220#[derive(Debug, Clone)]
221struct Transformer {
222 layers: Vec<AttentionLayer>,
223}
224
225impl Transformer {
226 fn new(
227 cfg: &Mistral3VisionConfig,
228 vb: ShardedVarBuilder,
229 normal_loading_metadata: &NormalLoadingMetadata,
230 ) -> Result<Self> {
231 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
232 let vb = vb.pp("layers");
233 for layer_idx in 0..cfg.num_hidden_layers {
234 let layer = AttentionLayer::new(cfg, vb.pp(layer_idx), normal_loading_metadata)?;
235 layers.push(layer)
236 }
237 Ok(Self { layers })
238 }
239
240 fn forward(
241 &self,
242 xs: &Tensor,
243 emb: &RotaryEmbedding,
244 subsampled_positions: Option<&Tensor>,
245 attention_mask: Option<&Tensor>,
246 ) -> Result<Tensor> {
247 let mut xs = xs.clone();
248 for layer in self.layers.iter() {
249 xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)?
250 }
251 Ok(xs)
252 }
253}
254
255#[derive(Debug, Clone)]
256struct RotaryEmbedding {
257 cos: Tensor,
258 sin: Tensor,
259}
260
261impl RotaryEmbedding {
262 fn new(cfg: &Mistral3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
263 let dtype = vb.dtype();
264 let dev = vb.device();
265 let dim = cfg.head_dim();
266 let rope_theta = cfg.rope_theta as f32;
267 let max_patches_per_side = cfg.image_size / cfg.patch_size;
268 let freqs: Vec<_> = (0..dim)
269 .step_by(2)
270 .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
271 .collect();
272 let freqs_h = freqs.iter().step_by(2).copied().collect::<Vec<_>>();
273 let freqs_h = Tensor::new(freqs_h, dev)?;
274 let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::<Vec<_>>();
275 let freqs_w = Tensor::new(freqs_w, dev)?;
276 let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
277 let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
278 let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?;
279 let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?;
280 let inv_freq = Tensor::cat(
281 &[
282 freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?,
283 freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?,
284 ],
285 D::Minus1,
286 )?
287 .reshape(((), dim / 2))?;
288 let cos = inv_freq.cos()?.to_dtype(dtype)?;
289 let sin = inv_freq.sin()?.to_dtype(dtype)?;
290 Ok(Self { cos, sin })
291 }
292
293 fn apply_rotary_emb_qkv(
294 &self,
295 q: &Tensor,
296 k: &Tensor,
297 subsampled_positions: Option<&Tensor>,
298 ) -> Result<(Tensor, Tensor)> {
299 let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
300 let (cos, sin) = match subsampled_positions {
301 None => (&self.cos, &self.sin),
302 Some(pos) => (
303 &self.cos.index_select(pos, 0)?,
304 &self.sin.index_select(pos, 0)?,
305 ),
306 };
307 let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
308 let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
309 Ok((q_embed, k_embed))
310 }
311}
312
313#[derive(Debug, Clone)]
314pub struct Mistral3VisionModel {
315 patch_conv: candle_nn::Conv2d,
316 ln_pre: RmsNorm,
317 transformer: Transformer,
318 patch_positional_embedding: RotaryEmbedding,
319 max_image_width: u32,
320 patch_size: usize,
321 dtype: DType,
322}
323
324impl Mistral3VisionModel {
325 pub fn new(
326 cfg: &Mistral3VisionConfig,
327 vb: ShardedVarBuilder,
328 normal_loading_metadata: &NormalLoadingMetadata,
329 ) -> Result<Self> {
330 let conv2d_cfg = candle_nn::Conv2dConfig {
331 stride: cfg.patch_size,
332 ..Default::default()
333 };
334 let patch_conv = layers::conv2d_no_bias(
335 cfg.num_channels,
336 cfg.hidden_size,
337 cfg.patch_size,
338 conv2d_cfg,
339 vb.pp("patch_conv")
340 .set_device(normal_loading_metadata.real_device.clone()),
341 )?;
342 let ln_pre = RmsNorm::new(
343 cfg.hidden_size,
344 1e-5,
345 vb.pp("ln_pre")
346 .set_device(normal_loading_metadata.real_device.clone()),
347 )?;
348 let transformer = Transformer::new(cfg, vb.pp("transformer"), normal_loading_metadata)?;
349 let patch_positional_embedding = RotaryEmbedding::new(
350 cfg,
351 vb.pp("patch_positional_embedding")
352 .set_device(normal_loading_metadata.real_device.clone()),
353 )?;
354 let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
355 Ok(Self {
356 patch_conv,
357 ln_pre,
358 transformer,
359 patch_positional_embedding,
360 max_image_width,
361 patch_size: cfg.patch_size,
362 dtype: vb.dtype(),
363 })
364 }
365
366 fn position_ids_in_meshgrid(
367 &self,
368 patch_embeds_list: &Vec<Tensor>,
369 device: &Device,
370 ) -> Result<Tensor> {
371 let mut positions = Vec::new();
372 for patch in patch_embeds_list {
373 let (height, width) = (patch.dim(D::Minus2)?, patch.dim(D::Minus1)?);
374 let idx = Tensor::arange(0, height as u32, device)?;
375 let idy = Tensor::arange(0, width as u32, device)?;
376 let mesh = Tensor::meshgrid(&[idx, idy], false)?;
377 let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?;
378 positions.push(ids);
379 }
380 Tensor::cat(&positions, 0)
381 }
382
383 fn generate_block_attention_mask(
384 &self,
385 patch_embeds_list: Vec<usize>,
386 patch_embeds: &Tensor,
387 ) -> Result<Tensor> {
388 let seq_len = patch_embeds.dim(1)?;
389 let mut causal_mask = (Tensor::ones(
390 (seq_len, seq_len),
391 patch_embeds.dtype(),
392 patch_embeds.device(),
393 )? * patch_embeds.dtype().finfo()?.min)?;
394
395 let block_end_idx: Vec<usize> = patch_embeds_list.iter().fold(Vec::new(), |mut acc, &x| {
396 let new_sum = x + acc.last().copied().unwrap_or(0);
397 acc.push(new_sum);
398 acc
399 });
400 let block_start_idx: Vec<usize> = {
401 let mut extended = vec![0];
402 extended.extend_from_slice(&patch_embeds_list[..patch_embeds_list.len() - 1]);
403 extended.into_iter().fold(Vec::new(), |mut acc, x| {
404 let new_sum = x + acc.last().copied().unwrap_or(0);
405 acc.push(new_sum);
406 acc
407 })
408 };
409 for (start, end) in block_start_idx.into_iter().zip(block_end_idx) {
410 causal_mask = causal_mask.slice_assign(
411 &[&(start..end), &(start..end)],
412 &Tensor::zeros(
413 (end - start, end - start),
414 causal_mask.dtype(),
415 causal_mask.device(),
416 )?,
417 )?;
418 }
419
420 causal_mask
421 .reshape((1, 1, causal_mask.dim(0)?, causal_mask.dim(1)?))?
422 .repeat((patch_embeds.dim(0)?, 1, 1, 1))
423 }
424
425 pub fn forward(&self, xs: &Tensor, image_sizes: Vec<(u32, u32)>) -> Result<Tensor> {
426 let patch_embeds = xs.apply(&self.patch_conv)?;
427 let patch_embeds_list = image_sizes
428 .iter()
429 .enumerate()
430 .map(|(i, &size)| {
431 patch_embeds
432 .i(i)?
433 .narrow(D::Minus2, 0, size.0 as usize / self.patch_size)?
434 .narrow(D::Minus1, 0, size.1 as usize / self.patch_size)
435 })
436 .collect::<Result<Vec<Tensor>>>()?;
437 let patch_embeds = Tensor::cat(
438 &patch_embeds_list
439 .iter()
440 .map(|p| p.flatten_from(1)?.t())
441 .collect::<Result<Vec<Tensor>>>()?,
442 0,
443 )?
444 .unsqueeze(0)?;
445 let patch_embeds = patch_embeds.apply(&self.ln_pre)?;
446
447 let subsampled_positions =
448 Some(self.position_ids_in_meshgrid(&patch_embeds_list, patch_embeds.device())?);
449
450 let attention_mask = self.generate_block_attention_mask(
451 patch_embeds_list
452 .iter()
453 .map(|p| Ok(p.dim(D::Minus2)? * p.dim(D::Minus1)?))
454 .collect::<Result<Vec<usize>>>()?,
455 &patch_embeds,
456 )?;
457
458 self.transformer.forward(
459 &patch_embeds,
460 &self.patch_positional_embedding,
461 subsampled_positions.as_ref(),
462 Some(&attention_mask),
463 )
464 }
465
466 pub fn dtype(&self) -> DType {
467 self.dtype
468 }
469
470 pub fn get_layers(&mut self) -> Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)> {
471 let mut tensors = Vec::new();
472 for layer in &mut self.transformer.layers {
473 tensors.push((&mut layer.attention.q_proj, None));
474 tensors.push((&mut layer.attention.k_proj, None));
475 tensors.push((&mut layer.attention.v_proj, None));
476 tensors.push((&mut layer.attention.o_proj, None));
477
478 tensors.push((&mut layer.feed_forward.gate_proj, None));
479 tensors.push((&mut layer.feed_forward.up_proj, None));
480 tensors.push((&mut layer.feed_forward.down_proj, None));
481 }
482 tensors
483 }
484}