// mistralrs_core/vision_models/phi4/mm_embedding.rs

use candle_core::{bail, Result, Tensor, D};
use candle_nn::Module;
use mistralrs_quant::ShardedVarBuilder;

use crate::utils::unvarbuilder::UnVarBuilder;

use super::{image_embedding::ImageEmbedding, Phi4MMConfig};

/// Magnitude bound for the sanity check on `image_input_id` performed in
/// `Phi4MMImageAudioEmbedding::forward`, which requires the id to be greater than
/// `-MAX_INPUT_ID`.
const MAX_INPUT_ID: f64 = 1e9;

11pub struct Phi4MMImageAudioEmbedding {
12    image_embed: Option<ImageEmbedding>,
13    image_input_id: f64,
14    wte: candle_nn::Embedding,
15}
16
17impl Phi4MMImageAudioEmbedding {
18    pub fn new(
19        cfg: &Phi4MMConfig,
20        wte: candle_nn::Embedding,
21        vb: ShardedVarBuilder,
22    ) -> Result<Self> {
23        let image_embed = if let Some(img_embd_config) = &cfg.embd_layer.image_embd_layer {
24            Some(ImageEmbedding::new(
25                cfg,
26                img_embd_config,
27                wte.clone(),
28                vb.pp("image_embed"),
29            )?)
30        } else {
31            None
32        };
33
34        Ok(Self {
35            image_embed,
36            image_input_id: cfg.image_input_id.unwrap_or(-1.),
37            wte,
38        })
39    }
40
41    pub fn forward(
42        &self,
43        input_ids: &Tensor,
44        input_image_embeds: &Tensor,
45        image_attention_mask: Option<&Tensor>,
46        image_sizes: Option<Vec<(u32, u32)>>,
47    ) -> Result<Tensor> {
48        assert!(-MAX_INPUT_ID < self.image_input_id);
49
50        let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;
51
52        let image_hidden_states = if let Some(image_embed) = &self.image_embed {
53            Some(image_embed.forward(
54                &input_ids,
55                input_image_embeds,
56                image_attention_mask,
57                image_sizes,
58            )?)
59        } else {
60            None
61        };
62
63        match image_hidden_states {
64            Some(image_hidden_states) => Ok(image_hidden_states),
65
66            None => self.wte.forward(&input_ids),
67        }
68    }
69
70    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
71        let uvb = UnVarBuilder::new();
72
73        if let Some(image_embed) = &self.image_embed {
74            uvb.pp("image_embed").extend(image_embed.residual_tensors());
75        }
76
77        uvb.to_safetensors()
78    }
79}