mistralrs_core/vision_models/llama4/
config.rs1use mistralrs_quant::QuantizedConfig;
2use serde::{Deserialize, Serialize};
3
4use crate::{
5 layers::{Activation, Llama3RopeConfig},
6 serde_default_fn,
7};
8
9serde_default_fn!(bool, word_emb_default, false);
10serde_default_fn!(bool, use_flash_attn_default, false);
11serde_default_fn!(Option<f32>, attn_temperature_tuning, Some(4.));
12serde_default_fn!(Option<f32>, floor_scale, Some(8192.));
13serde_default_fn!(Option<f32>, attn_scale, Some(0.1));
14
15#[derive(Debug, Clone, Deserialize, Serialize, Default)]
16pub struct TextConfig {
17 pub hidden_act: Activation,
18 pub hidden_size: usize,
19 pub intermediate_size: usize,
20 pub vocab_size: usize,
21 pub num_hidden_layers: usize,
22 pub num_attention_heads: usize,
23 pub num_key_value_heads: usize,
24 #[serde(default = "use_flash_attn_default")]
25 pub use_flash_attn: bool,
26 pub rms_norm_eps: f64,
27 pub rope_theta: f32,
28 pub max_position_embeddings: usize,
29 pub rope_scaling: Option<Llama3RopeConfig>,
30 pub quantization_config: Option<QuantizedConfig>,
31 #[serde(default = "word_emb_default")]
32 pub tie_word_embeddings: bool,
33 #[serde(default = "floor_scale")]
34 pub floor_scale: Option<f32>,
35 #[serde(default = "attn_scale")]
36 pub attn_scale: Option<f32>,
37 #[serde(default = "attn_temperature_tuning")]
38 pub attn_temperature_tuning: Option<f32>,
39 pub use_qk_norm: bool,
40 pub moe_layers: Option<Vec<usize>>,
41 pub interleave_moe_layer_step: usize,
42 pub intermediate_size_mlp: usize,
43 pub num_local_experts: usize,
44 pub num_experts_per_tok: usize,
45 pub attention_chunk_size: usize,
46}
47
48impl TextConfig {
49 pub fn moe_layers(&self) -> Vec<usize> {
50 self.moe_layers.clone().unwrap_or(
51 (self.interleave_moe_layer_step - 1..self.num_hidden_layers)
52 .step_by(self.interleave_moe_layer_step)
53 .collect(),
54 )
55 }
56}
57
58#[derive(Debug, Clone, serde::Deserialize)]
59pub enum VisionFeatureSelectStrategy {
60 #[serde(rename = "default")]
61 Default,
62}
63
64#[derive(Debug, Clone, serde::Deserialize)]
65pub struct VisionConfig {
66 pub hidden_size: usize,
67 pub hidden_act: Activation,
68 pub num_hidden_layers: usize,
69 pub num_attention_heads: usize,
70 pub num_channels: usize,
71 pub intermediate_size: usize,
72 pub vision_output_dim: usize,
73 pub image_size: usize,
74 pub patch_size: usize,
75 pub norm_eps: f64,
76 pub pixel_shuffle_ratio: f32,
77 pub projector_input_dim: usize,
78 pub projector_output_dim: usize,
79 pub vision_feature_layer: isize,
80 pub rope_theta: f32,
81}
82
83impl VisionConfig {
84 pub fn num_patches(&self) -> usize {
85 (self.image_size / self.patch_size).pow(2) + 1
86 }
87}
88
89#[derive(Debug, Clone, Deserialize)]
90pub struct Llama4Config {
91 pub text_config: TextConfig,
92 pub vision_config: VisionConfig,
93 pub image_token_index: usize,
94}