// mistralrs_core/vision_models/gemma3/config.rs

use mistralrs_quant::QuantizedConfig;

use crate::{
    layers::{Activation, Gemma3RopeScalingConfig},
    serde_default_fn,
    vision_models::siglip::SiglipVisionConfig,
};

9serde_default_fn!(bool, attention_bias, false);
10serde_default_fn!(usize, head_dim, 256);
11serde_default_fn!(Activation, hidden_activation, Activation::GeluPytorchTanh);
12serde_default_fn!(f64, rms_norm_eps, 1e-6);
13serde_default_fn!(f64, rope_theta, 1000000.);
14serde_default_fn!(usize, vocab_size, 262208);
15serde_default_fn!(bool, tie_word_embeddings, true);
16serde_default_fn!(usize, query_pre_attn_scalar, 256);
17serde_default_fn!(usize, max_position_embeddings, 131072);
18serde_default_fn!(bool, use_flash_attn, false);
19serde_default_fn!(f64, rope_local_base_freq, 10000.);
20serde_default_fn!(usize, sliding_window_pattern, 6);
21serde_default_fn!(usize, num_attention_heads, 8);
22serde_default_fn!(usize, num_key_value_heads, 4);
23
24#[derive(Debug, Clone, serde::Deserialize)]
25pub struct Gemma3TextConfig {
26    #[serde(default = "attention_bias")]
27    pub attention_bias: bool,
28    #[serde(default = "head_dim")]
29    pub head_dim: usize,
30    #[serde(default = "hidden_activation")]
31    pub hidden_activation: Activation,
32    pub hidden_size: usize,
33    pub intermediate_size: usize,
34    #[serde(default = "num_attention_heads")]
35    pub num_attention_heads: usize,
36    pub num_hidden_layers: usize,
37    #[serde(default = "num_key_value_heads")]
38    pub num_key_value_heads: usize,
39    #[serde(default = "rms_norm_eps")]
40    pub rms_norm_eps: f64,
41    #[serde(default = "rope_theta")]
42    pub rope_theta: f64,
43    #[serde(default = "vocab_size")]
44    pub vocab_size: usize,
45    pub sliding_window: usize,
46    pub attn_logit_softcapping: Option<f64>,
47    pub final_logit_softcapping: Option<f64>,
48    #[serde(default = "query_pre_attn_scalar")]
49    pub query_pre_attn_scalar: usize,
50    #[serde(default = "max_position_embeddings")]
51    pub max_position_embeddings: usize,
52    pub quantization_config: Option<QuantizedConfig>,
53    #[serde(default = "use_flash_attn")]
54    pub use_flash_attn: bool,
55    #[serde(default = "tie_word_embeddings")]
56    pub tie_word_embeddings: bool,
57    #[serde(default = "rope_local_base_freq")]
58    pub rope_local_base_freq: f64,
59    #[serde(default = "sliding_window_pattern")]
60    pub sliding_window_pattern: usize,
61    pub rope_scaling: Option<Gemma3RopeScalingConfig>,
62}
63
64#[derive(Debug, Clone, serde::Deserialize)]
65pub enum Gemma3Config {
66    #[serde(untagged)]
67    WithVision {
68        text_config: Gemma3TextConfig,
69        vision_config: SiglipVisionConfig,
70        image_token_index: usize,
71        mm_tokens_per_image: usize,
72    },
73    #[serde(untagged)]
74    Text(Gemma3TextConfig),
75}