// mistralrs_core/vision_models/llava/config.rs

1use serde::Deserialize;
2
3use crate::layers::{Activation, Llama3RopeConfig};
4use crate::serde_default_fn;
5
6use crate::models::llama::Config as LLaMAConfig;
7use crate::models::mistral::Config as MistralConfig;
8use crate::vision_models::clip::{Activation as ClipActivation, ClipConfig};
9
/// Top-level LLaVA model configuration, deserialized from the checkpoint's
/// `config.json`. It bundles the text-backbone config, the CLIP vision-tower
/// config, and the projector / feature-selection settings, and is converted
/// to the concrete per-model configs via the `to_*_config` methods below.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    /// Candidate `(height, width)` image resolutions; `None` when the
    /// checkpoint does not provide `image_grid_pinpoints`.
    pub image_grid_pinpoints: Option<Vec<(u32, u32)>>,
    /// Name of the activation used by the multimodal projector (e.g. "gelu").
    pub projector_hidden_act: String,
    /// Configuration of the language backbone (LLaMA- or Mistral-family).
    pub text_config: LLaVATextConfig,
    /// Configuration of the CLIP vision tower.
    pub vision_config: LLaVAVisionConfig,
    /// Index of the vision-tower layer whose hidden states are fed to the
    /// projector; signed so negative values can index from the end.
    pub vision_feature_layer: isize,
    /// Strategy string for selecting vision features (e.g. "default").
    pub vision_feature_select_strategy: String,
    /// Whether to use flash attention; defaults to `false` when the field is
    /// absent from the JSON.
    #[serde(default = "default_use_flash_attn")]
    pub use_flash_attn: bool,
}

// Generates `fn default_use_flash_attn() -> bool { false }` for the
// `#[serde(default = "...")]` attribute above.
serde_default_fn!(bool, default_use_flash_attn, false);
23
/// Text-backbone section of the LLaVA config. Most fields carry serde
/// defaults (defined below) so older/partial checkpoints still deserialize;
/// the default values appear to track a LLaMA-7B-class model — TODO confirm
/// against the upstream checkpoints.
#[derive(Deserialize, Debug, Clone)]
pub struct LLaVATextConfig {
    #[serde(default = "default_hidden_size")]
    pub hidden_size: usize,
    #[serde(default = "default_intermediate_size")]
    pub intermediate_size: usize,
    #[serde(default = "default_max_length")]
    pub max_length: usize,
    // Required: no serde default, must be present in the JSON.
    pub max_position_embeddings: usize,
    // Required: discriminates the backbone family (e.g. "llama" / "mistral").
    pub model_type: String,
    #[serde(default = "default_num_attention_heads")]
    pub num_attention_heads: usize,
    #[serde(default = "default_num_hidden_layers")]
    pub num_hidden_layers: usize,
    #[serde(default = "default_num_key_value_heads")]
    pub num_key_value_heads: usize,
    // Required: epsilon for RMSNorm layers.
    pub rms_norm_eps: f64,
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f32,
    #[serde(default = "default_vocab_size")]
    pub vocab_size: usize,
    /// Sliding-window attention size; only meaningful for Mistral backbones
    /// (it is forwarded in `to_mistral_config` but not `to_llama_config`).
    pub sliding_window: Option<usize>,
    /// Optional Llama-3-style RoPE scaling; forwarded to the LLaMA config.
    pub rope_scaling: Option<Llama3RopeConfig>,
}

// Default values used when the corresponding JSON field is missing.
serde_default_fn!(usize, default_num_hidden_layers, 32);
serde_default_fn!(usize, default_hidden_size, 4096);
serde_default_fn!(usize, default_intermediate_size, 11008);
serde_default_fn!(usize, default_max_length, 4096);
serde_default_fn!(usize, default_num_attention_heads, 32);
serde_default_fn!(usize, default_num_key_value_heads, 32);
serde_default_fn!(f32, default_rope_theta, 10000.0);
serde_default_fn!(usize, default_vocab_size, 32064);
57
/// Vision-tower section of the LLaVA config. All fields are required (no
/// serde defaults) and map one-to-one onto `ClipConfig` in `to_clip_config`.
#[derive(Deserialize, Debug, Clone)]
pub struct LLaVAVisionConfig {
    pub hidden_size: usize,
    // Input image edge length in pixels.
    pub image_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    // Edge length of each square patch the image is split into.
    pub patch_size: usize,
}
67
68impl Config {
69    pub fn to_llama_config(&self) -> LLaMAConfig {
70        LLaMAConfig {
71            hidden_size: self.text_config.hidden_size,
72            intermediate_size: self.text_config.intermediate_size,
73            vocab_size: self.text_config.vocab_size,
74            num_hidden_layers: self.text_config.num_hidden_layers,
75            num_attention_heads: self.text_config.num_attention_heads,
76            num_key_value_heads: self.text_config.num_key_value_heads,
77            use_flash_attn: self.use_flash_attn,
78            rms_norm_eps: self.text_config.rms_norm_eps,
79            rope_theta: self.text_config.rope_theta,
80            max_position_embeddings: self.text_config.max_position_embeddings,
81            rope_scaling: self.text_config.rope_scaling.clone(),
82            quantization_config: None,
83            tie_word_embeddings: false,
84            hidden_act: Activation::Silu,
85        }
86    }
87
88    pub fn to_mistral_config(&self) -> MistralConfig {
89        MistralConfig {
90            vocab_size: self.text_config.vocab_size,
91            hidden_size: self.text_config.hidden_size,
92            intermediate_size: self.text_config.intermediate_size,
93            num_hidden_layers: self.text_config.num_hidden_layers,
94            num_attention_heads: self.text_config.num_attention_heads,
95            num_key_value_heads: self.text_config.num_key_value_heads,
96            hidden_act: Activation::Silu, // as it is in mistralai/Mistral-7B-Instruct-v0.2
97            max_position_embeddings: self.text_config.max_position_embeddings,
98            rms_norm_eps: self.text_config.rms_norm_eps,
99            rope_theta: self.text_config.rope_theta as f64,
100            sliding_window: self.text_config.sliding_window,
101            use_flash_attn: self.use_flash_attn,
102            head_dim: None,
103            quantization_config: None,
104            tie_word_embeddings: false,
105        }
106    }
107
108    pub fn to_clip_config(&self) -> ClipConfig {
109        ClipConfig {
110            hidden_size: self.vision_config.hidden_size,
111            intermediate_size: self.vision_config.intermediate_size,
112            num_hidden_layers: self.vision_config.num_hidden_layers,
113            num_attention_heads: self.vision_config.num_attention_heads,
114            num_channels: 3,
115            image_size: self.vision_config.image_size,
116            patch_size: self.vision_config.patch_size,
117            hidden_act: ClipActivation::QuickGelu,
118        }
119    }
120}