// mistralrs_core/vision_models/llava/config.rs

use serde::Deserialize;

use crate::layers::{Activation, Llama3RopeConfig};
use crate::serde_default_fn;

use crate::models::llama::Config as LLaMAConfig;
use crate::models::mistral::Config as MistralConfig;
use crate::vision_models::clip::{Activation as ClipActivation, ClipConfig};
10#[derive(Debug, Clone, Deserialize)]
11pub struct Config {
12 pub image_grid_pinpoints: Option<Vec<(u32, u32)>>,
13 pub projector_hidden_act: String,
14 pub text_config: LLaVATextConfig,
15 pub vision_config: LLaVAVisionConfig,
16 pub vision_feature_layer: isize,
17 pub vision_feature_select_strategy: String,
18 #[serde(default = "default_use_flash_attn")]
19 pub use_flash_attn: bool,
20}
21
22serde_default_fn!(bool, default_use_flash_attn, false);
23
24#[derive(Deserialize, Debug, Clone)]
25pub struct LLaVATextConfig {
26 #[serde(default = "default_hidden_size")]
27 pub hidden_size: usize,
28 #[serde(default = "default_intermediate_size")]
29 pub intermediate_size: usize,
30 #[serde(default = "default_max_length")]
31 pub max_length: usize,
32 pub max_position_embeddings: usize,
33 pub model_type: String,
34 #[serde(default = "default_num_attention_heads")]
35 pub num_attention_heads: usize,
36 #[serde(default = "default_num_hidden_layers")]
37 pub num_hidden_layers: usize,
38 #[serde(default = "default_num_key_value_heads")]
39 pub num_key_value_heads: usize,
40 pub rms_norm_eps: f64,
41 #[serde(default = "default_rope_theta")]
42 pub rope_theta: f32,
43 #[serde(default = "default_vocab_size")]
44 pub vocab_size: usize,
45 pub sliding_window: Option<usize>,
46 pub rope_scaling: Option<Llama3RopeConfig>,
47}
48
49serde_default_fn!(usize, default_num_hidden_layers, 32);
50serde_default_fn!(usize, default_hidden_size, 4096);
51serde_default_fn!(usize, default_intermediate_size, 11008);
52serde_default_fn!(usize, default_max_length, 4096);
53serde_default_fn!(usize, default_num_attention_heads, 32);
54serde_default_fn!(usize, default_num_key_value_heads, 32);
55serde_default_fn!(f32, default_rope_theta, 10000.0);
56serde_default_fn!(usize, default_vocab_size, 32064);
57
58#[derive(Deserialize, Debug, Clone)]
59pub struct LLaVAVisionConfig {
60 pub hidden_size: usize,
61 pub image_size: usize,
62 pub intermediate_size: usize,
63 pub num_attention_heads: usize,
64 pub num_hidden_layers: usize,
65 pub patch_size: usize,
66}
67
68impl Config {
69 pub fn to_llama_config(&self) -> LLaMAConfig {
70 LLaMAConfig {
71 hidden_size: self.text_config.hidden_size,
72 intermediate_size: self.text_config.intermediate_size,
73 vocab_size: self.text_config.vocab_size,
74 num_hidden_layers: self.text_config.num_hidden_layers,
75 num_attention_heads: self.text_config.num_attention_heads,
76 num_key_value_heads: self.text_config.num_key_value_heads,
77 use_flash_attn: self.use_flash_attn,
78 rms_norm_eps: self.text_config.rms_norm_eps,
79 rope_theta: self.text_config.rope_theta,
80 max_position_embeddings: self.text_config.max_position_embeddings,
81 rope_scaling: self.text_config.rope_scaling.clone(),
82 quantization_config: None,
83 tie_word_embeddings: false,
84 hidden_act: Activation::Silu,
85 }
86 }
87
88 pub fn to_mistral_config(&self) -> MistralConfig {
89 MistralConfig {
90 vocab_size: self.text_config.vocab_size,
91 hidden_size: self.text_config.hidden_size,
92 intermediate_size: self.text_config.intermediate_size,
93 num_hidden_layers: self.text_config.num_hidden_layers,
94 num_attention_heads: self.text_config.num_attention_heads,
95 num_key_value_heads: self.text_config.num_key_value_heads,
96 hidden_act: Activation::Silu, max_position_embeddings: self.text_config.max_position_embeddings,
98 rms_norm_eps: self.text_config.rms_norm_eps,
99 rope_theta: self.text_config.rope_theta as f64,
100 sliding_window: self.text_config.sliding_window,
101 use_flash_attn: self.use_flash_attn,
102 head_dim: None,
103 quantization_config: None,
104 tie_word_embeddings: false,
105 }
106 }
107
108 pub fn to_clip_config(&self) -> ClipConfig {
109 ClipConfig {
110 hidden_size: self.vision_config.hidden_size,
111 intermediate_size: self.vision_config.intermediate_size,
112 num_hidden_layers: self.vision_config.num_hidden_layers,
113 num_attention_heads: self.vision_config.num_attention_heads,
114 num_channels: 3,
115 image_size: self.vision_config.image_size,
116 patch_size: self.vision_config.patch_size,
117 hidden_act: ClipActivation::QuickGelu,
118 }
119 }
120}