// mistralrs_core/vision_models/mllama/config.rs

1use candle_core::{Result, Tensor};
2use candle_nn::Module;
3use mistralrs_quant::QuantizedConfig;
4
5use crate::serde_default_fn;
6
/// Activation functions supported by the MLlama vision encoder.
///
/// Serde aliases accept the lowercase spellings used in HF `config.json`
/// (`"gelu"`, `"gelu_new"`); variants without an alias deserialize from
/// their CamelCase variant name.
#[derive(Debug, Clone, Copy, serde::Deserialize)]
pub(crate) enum VisionActivation {
    // x * sigmoid(1.702 * x): a cheap GELU approximation (see Module impl).
    QuickGelu,
    // Exact, erf-based GELU.
    #[serde(alias = "gelu")]
    Gelu,
    // Tanh-approximated "new" GELU.
    #[serde(alias = "gelu_new")]
    NewGelu,
    Relu,
    Silu,
}
17
18impl Module for VisionActivation {
19    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
20        match self {
21            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
22            Self::Gelu => xs.gelu_erf(),
23            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
24            Self::NewGelu => xs.gelu(),
25            Self::Relu => xs.relu(),
26            Self::Silu => xs.silu(),
27        }
28    }
29}
30
// Generates `fn d_attn_heads() -> usize { 16 }`, used as a serde default below.
serde_default_fn!(usize, d_attn_heads, 16);
32
/// Configuration for the MLlama vision tower, deserialized from the
/// `vision_config` section of the model's `config.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct MLlamaVisionConfig {
    pub(crate) hidden_size: usize,
    pub(crate) hidden_act: VisionActivation,
    // Number of local transformer layers.
    pub(crate) num_hidden_layers: usize,
    // Number of global transformer layers.
    pub(crate) num_global_layers: usize,
    // Defaults to 16 when absent from the config JSON.
    #[serde(default = "d_attn_heads")]
    pub(crate) num_attention_heads: usize,
    pub(crate) num_channels: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) vision_output_dim: usize,
    pub(crate) image_size: usize,
    pub(crate) patch_size: usize,
    pub(crate) norm_eps: f64,
    pub(crate) max_num_tiles: usize,
    // Indices of hidden layers whose outputs are surfaced as intermediate
    // features — assumed to index into the `num_hidden_layers` stack; confirm
    // against the model code.
    pub(crate) intermediate_layers_indices: Vec<usize>,
    // (rows, cols) tile arrangements the model accepts.
    pub(crate) supported_aspect_ratios: Vec<(usize, usize)>,
}
51
52impl MLlamaVisionConfig {
53    pub(crate) fn max_aspect_ratio_id(&self) -> usize {
54        self.supported_aspect_ratios.len()
55    }
56}
57
58#[derive(Debug, Clone, serde::Deserialize)]
59pub(crate) enum MLlamaRopeType {
60    #[serde(rename = "default")]
61    Default,
62    #[serde(rename = "linear")]
63    Linear,
64    #[serde(rename = "dynamic")]
65    Dynamic,
66    #[serde(rename = "yarn")]
67    Yarn,
68    #[serde(rename = "longrope")]
69    Longrope,
70    #[serde(rename = "llama3")]
71    Llama3,
72}
73
/// Parameters for RoPE scaling, mirroring HF's `rope_scaling` config object.
///
/// Which optional fields are populated depends on `rope_type` — presumably
/// `factor` for linear/dynamic, `beta_*`/`attention_factor` for yarn,
/// `short_factor`/`long_factor` for longrope, and `low/high_freq_factor`
/// for llama3; confirm against the RoPE construction code.
#[derive(Debug, Clone, serde::Deserialize)]
#[allow(dead_code)]
pub(crate) struct MLlamaRopeScaling {
    pub(crate) rope_type: MLlamaRopeType,
    pub(crate) factor: Option<f32>,
    pub(crate) original_max_position_embeddings: usize,
    pub(crate) attention_factor: Option<f32>,
    pub(crate) beta_fast: Option<f32>,
    pub(crate) beta_slow: Option<f32>,
    pub(crate) short_factor: Option<Vec<f64>>,
    pub(crate) long_factor: Option<Vec<f64>>,
    pub(crate) low_freq_factor: Option<f32>,
    pub(crate) high_freq_factor: Option<f32>,
}
88
// Generates `fn d_flash_attn() -> bool { false }`, used as a serde default below.
serde_default_fn!(bool, d_flash_attn, false);
90
/// Configuration for the MLlama text model, deserialized from the
/// `text_config` section of the model's `config.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct MLlamaTextConfig {
    pub(crate) rope_scaling: Option<MLlamaRopeScaling>,
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) hidden_act: candle_nn::Activation,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) rope_theta: f32,
    pub(crate) rms_norm_eps: f64,
    pub(crate) max_position_embeddings: usize,
    // When true, the LM head shares weights with the input embeddings.
    pub(crate) tie_word_embeddings: bool,
    // Indices of layers that cross-attend to the vision features.
    pub(crate) cross_attention_layers: Vec<usize>,
    // Defaults to false when absent from the config JSON.
    #[serde(default = "d_flash_attn")]
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
}
110
111impl MLlamaTextConfig {
112    pub(crate) fn head_dim(&self) -> usize {
113        self.hidden_size / self.num_attention_heads
114    }
115}
116
/// Top-level MLlama configuration: pairs the vision-tower and text-model
/// configs, matching the nested layout of the HF `config.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct MLlamaConfig {
    pub(crate) vision_config: MLlamaVisionConfig,
    pub(crate) text_config: MLlamaTextConfig,
}