// mistralrs_core/vision_models/phi4/config.rs

1use std::collections::HashMap;
2
3use mistralrs_quant::{QuantizedConfig, StaticLoraConfig};
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    layers::{Activation, Phi4MMRopeScalingConfig},
8    serde_default_fn,
9};
10
// Generates `fn d_flash_attn() -> bool { false }`, used as the serde default
// for `Phi4MMConfig::use_flash_attn` below (field is absent in older configs).
serde_default_fn!(bool, d_flash_attn, false);
12
/// Per-layer LoRA adapter hyperparameters as they appear in the model's
/// serialized config (deserialized directly via serde).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMLoraConfig {
    /// Name/path of the layer this adapter targets.
    pub layer: String,
    /// LoRA scaling factor (alpha).
    pub lora_alpha: f64,
    /// LoRA rank.
    pub r: usize,
}
19
/// Configuration for the image-embedding sub-layer of the Phi-4 multimodal
/// model. Most fields are `Option` because older/partial configs omit them.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMImageEmbedConfig {
    /// Embedding dimension of the projected image features, if specified.
    pub n_embd: Option<usize>,
    /// Image crop size in pixels — presumably square; verify against the processor.
    pub crop_size: Option<usize>,
    /// Name of the embedding class to instantiate (stringly-typed, mirrors the
    /// upstream Python config).
    pub embedding_cls: String,
    pub enable_gradient_checkpointing: bool,
    /// Ordering used by the HD transform (e.g. pre/post) — semantics defined by
    /// the embedding implementation, not here.
    pub hd_transform_order: Option<String>,
    /// Optional class name for image-token compression.
    pub image_token_compression_cls: Option<String>,
    /// Optional class name for the projection head.
    pub projection_cls: Option<String>,
    pub use_hd_transform: Option<bool>,
    pub with_learnable_separator: Option<bool>,
}
32
/// Container for modality-specific embedding-layer configs; currently only the
/// image embedding layer is modeled (audio is not deserialized — see the
/// commented-out `audio_processor` field on `Phi4MMConfig`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMEmbdLayerConfig {
    pub image_embd_layer: Option<Phi4MMImageEmbedConfig>,
}
37
/// Vision-tower feature-extraction settings.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMImgProcessorConfig {
    /// Which hidden layer to take features from; `isize` so negative values can
    /// index from the end, Python-style.
    pub layer_idx: Option<isize>,
    /// Feature selection strategy (e.g. patch vs. cls) — stringly-typed to
    /// mirror the upstream config; exact values defined by the consumer.
    pub type_feature: Option<String>,
}
43
/// Top-level configuration for the Phi-4 multimodal (Phi4MM) model,
/// deserialized from the model's `config.json` via serde.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMConfig {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    /// Absent implies MHA — see `num_key_value_heads()` accessor.
    pub num_key_value_heads: Option<usize>,
    pub resid_pdrop: f64,
    pub embd_pdrop: f64,
    pub attention_dropout: f64,
    pub hidden_act: Activation,
    pub max_position_embeddings: usize,
    pub original_max_position_embeddings: usize,
    pub initializer_range: f64,
    pub rms_norm_eps: f64,
    pub use_cache: bool,
    pub tie_word_embeddings: bool,
    pub rope_theta: f64,
    pub rope_scaling: Option<Phi4MMRopeScalingConfig>,
    /// Fraction of head dims that receive rotary embedding.
    pub partial_rotary_factor: f64,
    pub bos_token_id: usize,
    pub eos_token_id: usize,
    pub pad_token_id: usize,
    /// NOTE(review): an `f64` is unusual for a token id — mirrors the upstream
    /// config; confirm before using as an index.
    pub image_input_id: Option<f64>,
    pub sliding_window: Option<usize>,
    pub embd_layer: Phi4MMEmbdLayerConfig,
    pub img_processor: Option<Phi4MMImgProcessorConfig>,
    // Audio processor config is present upstream but not consumed here.
    // pub audio_processor: Option<String>,
    pub vision_lora: StaticLoraConfig,
    pub speech_lora: StaticLoraConfig,
    pub quantization_config: Option<QuantizedConfig>,
    /// Defaults to `false` when the field is missing from the JSON.
    #[serde(default = "d_flash_attn")]
    pub use_flash_attn: bool,
}
79
80impl Phi4MMConfig {
81    pub fn num_key_value_heads(&self) -> usize {
82        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
83    }
84
85    pub fn head_dim(&self) -> usize {
86        self.hidden_size / self.num_attention_heads
87    }
88
89    pub fn loras(&self) -> HashMap<String, StaticLoraConfig> {
90        let mut accum = HashMap::new();
91        // Add all the loras
92        accum.insert("speech".to_string(), self.speech_lora.clone());
93        accum.insert("vision".to_string(), self.vision_lora.clone());
94        accum
95    }
96}