// mistralrs_core/vision_models/phi4/config.rs
use std::collections::HashMap;
2
3use mistralrs_quant::{QuantizedConfig, StaticLoraConfig};
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 layers::{Activation, Phi4MMRopeScalingConfig},
8 serde_default_fn,
9};
10
11serde_default_fn!(bool, d_flash_attn, false);
12
13#[derive(Serialize, Deserialize, Debug, Clone)]
14pub struct Phi4MMLoraConfig {
15 pub layer: String,
16 pub lora_alpha: f64,
17 pub r: usize,
18}
19
20#[derive(Serialize, Deserialize, Debug, Clone)]
21pub struct Phi4MMImageEmbedConfig {
22 pub n_embd: Option<usize>,
23 pub crop_size: Option<usize>,
24 pub embedding_cls: String,
25 pub enable_gradient_checkpointing: bool,
26 pub hd_transform_order: Option<String>,
27 pub image_token_compression_cls: Option<String>,
28 pub projection_cls: Option<String>,
29 pub use_hd_transform: Option<bool>,
30 pub with_learnable_separator: Option<bool>,
31}
32
33#[derive(Serialize, Deserialize, Debug, Clone)]
34pub struct Phi4MMEmbdLayerConfig {
35 pub image_embd_layer: Option<Phi4MMImageEmbedConfig>,
36}
37
38#[derive(Serialize, Deserialize, Debug, Clone)]
39pub struct Phi4MMImgProcessorConfig {
40 pub layer_idx: Option<isize>,
41 pub type_feature: Option<String>,
42}
43
44#[derive(Serialize, Deserialize, Debug, Clone)]
45pub struct Phi4MMConfig {
46 pub vocab_size: usize,
47 pub hidden_size: usize,
48 pub intermediate_size: usize,
49 pub num_hidden_layers: usize,
50 pub num_attention_heads: usize,
51 pub num_key_value_heads: Option<usize>,
52 pub resid_pdrop: f64,
53 pub embd_pdrop: f64,
54 pub attention_dropout: f64,
55 pub hidden_act: Activation,
56 pub max_position_embeddings: usize,
57 pub original_max_position_embeddings: usize,
58 pub initializer_range: f64,
59 pub rms_norm_eps: f64,
60 pub use_cache: bool,
61 pub tie_word_embeddings: bool,
62 pub rope_theta: f64,
63 pub rope_scaling: Option<Phi4MMRopeScalingConfig>,
64 pub partial_rotary_factor: f64,
65 pub bos_token_id: usize,
66 pub eos_token_id: usize,
67 pub pad_token_id: usize,
68 pub image_input_id: Option<f64>,
69 pub sliding_window: Option<usize>,
70 pub embd_layer: Phi4MMEmbdLayerConfig,
71 pub img_processor: Option<Phi4MMImgProcessorConfig>,
72 pub vision_lora: StaticLoraConfig,
74 pub speech_lora: StaticLoraConfig,
75 pub quantization_config: Option<QuantizedConfig>,
76 #[serde(default = "d_flash_attn")]
77 pub use_flash_attn: bool,
78}
79
80impl Phi4MMConfig {
81 pub fn num_key_value_heads(&self) -> usize {
82 self.num_key_value_heads.unwrap_or(self.num_attention_heads)
83 }
84
85 pub fn head_dim(&self) -> usize {
86 self.hidden_size / self.num_attention_heads
87 }
88
89 pub fn loras(&self) -> HashMap<String, StaticLoraConfig> {
90 let mut accum = HashMap::new();
91 accum.insert("speech".to_string(), self.speech_lora.clone());
93 accum.insert("vision".to_string(), self.vision_lora.clone());
94 accum
95 }
96}