mistralrs_core/vision_models/phi4/
mm_embedding.rs1use candle_core::{Result, Tensor, D};
2use candle_nn::Module;
3use mistralrs_quant::ShardedVarBuilder;
4
5use crate::utils::unvarbuilder::UnVarBuilder;
6
7use super::{image_embedding::ImageEmbedding, Phi4MMConfig};
8
/// Lower-bound magnitude for special input ids: `forward` requires
/// `image_input_id` to be strictly greater than `-MAX_INPUT_ID`.
const MAX_INPUT_ID: f64 = 1_000_000_000.0;
10
11pub struct Phi4MMImageAudioEmbedding {
12 image_embed: Option<ImageEmbedding>,
13 image_input_id: f64,
14 wte: candle_nn::Embedding,
15}
16
17impl Phi4MMImageAudioEmbedding {
18 pub fn new(
19 cfg: &Phi4MMConfig,
20 wte: candle_nn::Embedding,
21 vb: ShardedVarBuilder,
22 ) -> Result<Self> {
23 let image_embed = if let Some(img_embd_config) = &cfg.embd_layer.image_embd_layer {
24 Some(ImageEmbedding::new(
25 cfg,
26 img_embd_config,
27 wte.clone(),
28 vb.pp("image_embed"),
29 )?)
30 } else {
31 None
32 };
33
34 Ok(Self {
35 image_embed,
36 image_input_id: cfg.image_input_id.unwrap_or(-1.),
37 wte,
38 })
39 }
40
41 pub fn forward(
42 &self,
43 input_ids: &Tensor,
44 input_image_embeds: &Tensor,
45 image_attention_mask: Option<&Tensor>,
46 image_sizes: Option<Vec<(u32, u32)>>,
47 ) -> Result<Tensor> {
48 assert!(-MAX_INPUT_ID < self.image_input_id);
49
50 let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;
51
52 let image_hidden_states = if let Some(image_embed) = &self.image_embed {
53 Some(image_embed.forward(
54 &input_ids,
55 input_image_embeds,
56 image_attention_mask,
57 image_sizes,
58 )?)
59 } else {
60 None
61 };
62
63 match image_hidden_states {
64 Some(image_hidden_states) => Ok(image_hidden_states),
65
66 None => self.wte.forward(&input_ids),
67 }
68 }
69
70 pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
71 let uvb = UnVarBuilder::new();
72
73 if let Some(image_embed) = &self.image_embed {
74 uvb.pp("image_embed").extend(image_embed.residual_tensors());
75 }
76
77 uvb.to_safetensors()
78 }
79}