mistralrs_core/vision_models/mllama/
config.rs1use candle_core::{Result, Tensor};
2use candle_nn::Module;
3use mistralrs_quant::QuantizedConfig;
4
5use crate::serde_default_fn;
6
7#[derive(Debug, Clone, Copy, serde::Deserialize)]
8pub(crate) enum VisionActivation {
9 QuickGelu,
10 #[serde(alias = "gelu")]
11 Gelu,
12 #[serde(alias = "gelu_new")]
13 NewGelu,
14 Relu,
15 Silu,
16}
17
18impl Module for VisionActivation {
19 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
20 match self {
21 Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
22 Self::Gelu => xs.gelu_erf(),
23 Self::NewGelu => xs.gelu(),
25 Self::Relu => xs.relu(),
26 Self::Silu => xs.silu(),
27 }
28 }
29}
30
// Default for `num_attention_heads` when absent from the config JSON: 16.
serde_default_fn!(usize, d_attn_heads, 16);
32
/// Configuration of the MLlama vision encoder, deserialized from the
/// `vision_config` section of a HuggingFace `config.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct MLlamaVisionConfig {
    // Width of the encoder hidden states.
    pub(crate) hidden_size: usize,
    // Activation used in the encoder MLPs (see `VisionActivation`).
    pub(crate) hidden_act: VisionActivation,
    // Number of local transformer layers.
    pub(crate) num_hidden_layers: usize,
    // Number of global transformer layers.
    pub(crate) num_global_layers: usize,
    // Attention heads per layer; defaults to 16 if missing from the JSON.
    #[serde(default = "d_attn_heads")]
    pub(crate) num_attention_heads: usize,
    // Input image channels (presumably 3 for RGB — not enforced here).
    pub(crate) num_channels: usize,
    // MLP inner dimension.
    pub(crate) intermediate_size: usize,
    // Dimension of the vision features handed to the text model.
    pub(crate) vision_output_dim: usize,
    // Edge length in pixels (NOTE(review): presumably per tile — confirm
    // against the image processor).
    pub(crate) image_size: usize,
    // Patch edge length in pixels.
    pub(crate) patch_size: usize,
    // Epsilon for the encoder layer norms.
    pub(crate) norm_eps: f64,
    // Maximum number of tiles an image may be split into.
    pub(crate) max_num_tiles: usize,
    // Indices of hidden layers whose outputs are also returned as features.
    pub(crate) intermediate_layers_indices: Vec<usize>,
    // Allowed (rows, cols) tile layouts; its length bounds the aspect-ratio id
    // (see `max_aspect_ratio_id`).
    pub(crate) supported_aspect_ratios: Vec<(usize, usize)>,
}
51
52impl MLlamaVisionConfig {
53 pub(crate) fn max_aspect_ratio_id(&self) -> usize {
54 self.supported_aspect_ratios.len()
55 }
56}
57
58#[derive(Debug, Clone, serde::Deserialize)]
59pub(crate) enum MLlamaRopeType {
60 #[serde(rename = "default")]
61 Default,
62 #[serde(rename = "linear")]
63 Linear,
64 #[serde(rename = "dynamic")]
65 Dynamic,
66 #[serde(rename = "yarn")]
67 Yarn,
68 #[serde(rename = "longrope")]
69 Longrope,
70 #[serde(rename = "llama3")]
71 Llama3,
72}
73
/// RoPE scaling parameters from `config.json`'s `rope_scaling` object.
/// Most fields are optional because each scaling strategy reads a
/// different subset; unused fields are simply `None`.
#[derive(Debug, Clone, serde::Deserialize)]
#[allow(dead_code)]
pub(crate) struct MLlamaRopeScaling {
    // Which scaling strategy the other fields parameterize.
    pub(crate) rope_type: MLlamaRopeType,
    // Scaling factor (presumably used by linear/dynamic/yarn — TODO confirm
    // which variants read it in the RoPE implementation).
    pub(crate) factor: Option<f32>,
    // Context length the base model was trained with.
    pub(crate) original_max_position_embeddings: usize,
    // Attention temperature adjustment (yarn-style scaling).
    pub(crate) attention_factor: Option<f32>,
    // Yarn interpolation boundaries.
    pub(crate) beta_fast: Option<f32>,
    pub(crate) beta_slow: Option<f32>,
    // Per-dimension factors for short/long contexts (longrope-style).
    pub(crate) short_factor: Option<Vec<f64>>,
    pub(crate) long_factor: Option<Vec<f64>>,
    // Frequency-band boundaries (llama3-style scaling).
    pub(crate) low_freq_factor: Option<f32>,
    pub(crate) high_freq_factor: Option<f32>,
}
88
// Default for `use_flash_attn` when absent from the config JSON: false.
serde_default_fn!(bool, d_flash_attn, false);
90
/// Configuration of the MLlama text (language) model, deserialized from the
/// `text_config` section of a HuggingFace `config.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct MLlamaTextConfig {
    // Optional RoPE scaling; `None` means unscaled RoPE.
    pub(crate) rope_scaling: Option<MLlamaRopeScaling>,
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    // MLP activation, parsed by candle's own `Activation` deserializer.
    pub(crate) hidden_act: candle_nn::Activation,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    // Key/value heads for grouped-query attention.
    pub(crate) num_key_value_heads: usize,
    pub(crate) intermediate_size: usize,
    // RoPE base frequency.
    pub(crate) rope_theta: f32,
    pub(crate) rms_norm_eps: f64,
    pub(crate) max_position_embeddings: usize,
    // Whether the LM head shares weights with the input embedding.
    pub(crate) tie_word_embeddings: bool,
    // Indices of layers that cross-attend to the vision features.
    pub(crate) cross_attention_layers: Vec<usize>,
    // Defaults to false when missing from the JSON.
    #[serde(default = "d_flash_attn")]
    pub(crate) use_flash_attn: bool,
    // Present only for pre-quantized checkpoints.
    pub(crate) quantization_config: Option<QuantizedConfig>,
}
110
111impl MLlamaTextConfig {
112 pub(crate) fn head_dim(&self) -> usize {
113 self.hidden_size / self.num_attention_heads
114 }
115}
116
/// Top-level MLlama configuration: the vision and text sub-configs as they
/// appear side by side in a HuggingFace `config.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct MLlamaConfig {
    pub(crate) vision_config: MLlamaVisionConfig,
    pub(crate) text_config: MLlamaTextConfig,
}