// mistralrs_core/vision_models/llama4/config.rs

1use mistralrs_quant::QuantizedConfig;
2use serde::{Deserialize, Serialize};
3
4use crate::{
5    layers::{Activation, Llama3RopeConfig},
6    serde_default_fn,
7};
8
9serde_default_fn!(bool, word_emb_default, false);
10serde_default_fn!(bool, use_flash_attn_default, false);
11serde_default_fn!(Option<f32>, attn_temperature_tuning, Some(4.));
12serde_default_fn!(Option<f32>, floor_scale, Some(8192.));
13serde_default_fn!(Option<f32>, attn_scale, Some(0.1));
14
15#[derive(Debug, Clone, Deserialize, Serialize, Default)]
16pub struct TextConfig {
17    pub hidden_act: Activation,
18    pub hidden_size: usize,
19    pub intermediate_size: usize,
20    pub vocab_size: usize,
21    pub num_hidden_layers: usize,
22    pub num_attention_heads: usize,
23    pub num_key_value_heads: usize,
24    #[serde(default = "use_flash_attn_default")]
25    pub use_flash_attn: bool,
26    pub rms_norm_eps: f64,
27    pub rope_theta: f32,
28    pub max_position_embeddings: usize,
29    pub rope_scaling: Option<Llama3RopeConfig>,
30    pub quantization_config: Option<QuantizedConfig>,
31    #[serde(default = "word_emb_default")]
32    pub tie_word_embeddings: bool,
33    #[serde(default = "floor_scale")]
34    pub floor_scale: Option<f32>,
35    #[serde(default = "attn_scale")]
36    pub attn_scale: Option<f32>,
37    #[serde(default = "attn_temperature_tuning")]
38    pub attn_temperature_tuning: Option<f32>,
39    pub use_qk_norm: bool,
40    pub moe_layers: Option<Vec<usize>>,
41    pub interleave_moe_layer_step: usize,
42    pub intermediate_size_mlp: usize,
43    pub num_local_experts: usize,
44    pub num_experts_per_tok: usize,
45    pub attention_chunk_size: usize,
46}
47
48impl TextConfig {
49    pub fn moe_layers(&self) -> Vec<usize> {
50        self.moe_layers.clone().unwrap_or(
51            (self.interleave_moe_layer_step - 1..self.num_hidden_layers)
52                .step_by(self.interleave_moe_layer_step)
53                .collect(),
54        )
55    }
56}
57
58#[derive(Debug, Clone, serde::Deserialize)]
59pub enum VisionFeatureSelectStrategy {
60    #[serde(rename = "default")]
61    Default,
62}
63
64#[derive(Debug, Clone, serde::Deserialize)]
65pub struct VisionConfig {
66    pub hidden_size: usize,
67    pub hidden_act: Activation,
68    pub num_hidden_layers: usize,
69    pub num_attention_heads: usize,
70    pub num_channels: usize,
71    pub intermediate_size: usize,
72    pub vision_output_dim: usize,
73    pub image_size: usize,
74    pub patch_size: usize,
75    pub norm_eps: f64,
76    pub pixel_shuffle_ratio: f32,
77    pub projector_input_dim: usize,
78    pub projector_output_dim: usize,
79    pub vision_feature_layer: isize,
80    pub rope_theta: f32,
81}
82
83impl VisionConfig {
84    pub fn num_patches(&self) -> usize {
85        (self.image_size / self.patch_size).pow(2) + 1
86    }
87}
88
89#[derive(Debug, Clone, Deserialize)]
90pub struct Llama4Config {
91    pub text_config: TextConfig,
92    pub vision_config: VisionConfig,
93    pub image_token_index: usize,
94}