mistralrs_core/vision_models/gemma3/
config.rs1use mistralrs_quant::QuantizedConfig;
2
3use crate::{
4 layers::{Activation, Gemma3RopeScalingConfig},
5 serde_default_fn,
6 vision_models::siglip::SiglipVisionConfig,
7};
8
9serde_default_fn!(bool, attention_bias, false);
10serde_default_fn!(usize, head_dim, 256);
11serde_default_fn!(Activation, hidden_activation, Activation::GeluPytorchTanh);
12serde_default_fn!(f64, rms_norm_eps, 1e-6);
13serde_default_fn!(f64, rope_theta, 1000000.);
14serde_default_fn!(usize, vocab_size, 262208);
15serde_default_fn!(bool, tie_word_embeddings, true);
16serde_default_fn!(usize, query_pre_attn_scalar, 256);
17serde_default_fn!(usize, max_position_embeddings, 131072);
18serde_default_fn!(bool, use_flash_attn, false);
19serde_default_fn!(f64, rope_local_base_freq, 10000.);
20serde_default_fn!(usize, sliding_window_pattern, 6);
21serde_default_fn!(usize, num_attention_heads, 8);
22serde_default_fn!(usize, num_key_value_heads, 4);
23
24#[derive(Debug, Clone, serde::Deserialize)]
25pub struct Gemma3TextConfig {
26 #[serde(default = "attention_bias")]
27 pub attention_bias: bool,
28 #[serde(default = "head_dim")]
29 pub head_dim: usize,
30 #[serde(default = "hidden_activation")]
31 pub hidden_activation: Activation,
32 pub hidden_size: usize,
33 pub intermediate_size: usize,
34 #[serde(default = "num_attention_heads")]
35 pub num_attention_heads: usize,
36 pub num_hidden_layers: usize,
37 #[serde(default = "num_key_value_heads")]
38 pub num_key_value_heads: usize,
39 #[serde(default = "rms_norm_eps")]
40 pub rms_norm_eps: f64,
41 #[serde(default = "rope_theta")]
42 pub rope_theta: f64,
43 #[serde(default = "vocab_size")]
44 pub vocab_size: usize,
45 pub sliding_window: usize,
46 pub attn_logit_softcapping: Option<f64>,
47 pub final_logit_softcapping: Option<f64>,
48 #[serde(default = "query_pre_attn_scalar")]
49 pub query_pre_attn_scalar: usize,
50 #[serde(default = "max_position_embeddings")]
51 pub max_position_embeddings: usize,
52 pub quantization_config: Option<QuantizedConfig>,
53 #[serde(default = "use_flash_attn")]
54 pub use_flash_attn: bool,
55 #[serde(default = "tie_word_embeddings")]
56 pub tie_word_embeddings: bool,
57 #[serde(default = "rope_local_base_freq")]
58 pub rope_local_base_freq: f64,
59 #[serde(default = "sliding_window_pattern")]
60 pub sliding_window_pattern: usize,
61 pub rope_scaling: Option<Gemma3RopeScalingConfig>,
62}
63
64#[derive(Debug, Clone, serde::Deserialize)]
65pub enum Gemma3Config {
66 #[serde(untagged)]
67 WithVision {
68 text_config: Gemma3TextConfig,
69 vision_config: SiglipVisionConfig,
70 image_token_index: usize,
71 mm_tokens_per_image: usize,
72 },
73 #[serde(untagged)]
74 Text(Gemma3TextConfig),
75}