mistralrs_core/vision_models/llama4/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

mod text;

use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use text::TextModel;
use vision::Llama4VisionModel;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::linear_no_bias,
    ops::NonZeroOp,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
    },
    utils::unvarbuilder::UnVarBuilder,
};

mod config;
mod inputs_processor;
mod vision;

pub(crate) use config::{Llama4Config, TextConfig};
pub(crate) use inputs_processor::{Llama4ImageProcessor, Llama4Processor, IMAGE_TOKEN};

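/// Projects flattened vision-encoder outputs (`vision_output_dim`) into the
/// language model's `hidden_size` so image features can be spliced into the
/// text embedding stream.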
struct Llama4MultiModalProjector {
    linear_1: Linear,
}

impl Llama4MultiModalProjector {
    fn new(cfg: &Llama4Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            linear_1: linear_no_bias(
                cfg.vision_config.vision_output_dim,
                cfg.text_config.hidden_size,
                vb.pp("linear_1"),
            )?,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.linear_1.forward(xs)
    }
}

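/// Llama 4 multimodal model: a vision tower, a projector, and the Llama 4 text
/// model. During a multimodal forward pass, projected image features overwrite
/// the placeholder embeddings at `image_token_index` positions before the text
/// model runs.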
pub struct Llama4Model {
    language_model: TextModel,
    vision_model: Llama4VisionModel,
    multi_modal_projector: Llama4MultiModalProjector,
    image_token_index: usize,
}

impl Llama4Model {
    pub fn new(
        cfg: &Llama4Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vision_model = Llama4VisionModel::new(
            &cfg.vision_config,
            vb.pp("vision_model"),
            &normal_loading_metadata.real_device,
            &normal_loading_metadata.mapper.get_comm_for(0)?,
            &normal_loading_metadata.multi_progress,
        )?;
        let multi_modal_projector = Llama4MultiModalProjector::new(
            cfg,
            vb.pp("multi_modal_projector")
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        let language_model = TextModel::new(
            &cfg.text_config,
            vb.pp("language_model"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;

        Ok(Self {
            language_model,
            vision_model,
            multi_modal_projector,
            image_token_index: cfg.image_token_index,
        })
    }

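    /// Embeds `input_ids`; if `pixel_values` are present, runs the vision
    /// tower, projects its output to the text hidden size, and writes the
    /// projected features into the embedding positions marked by the image
    /// token (a gather/scatter_add in-place merge). The result is fed to the
    /// text model as precomputed embeddings.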
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut input_embeds = self.language_model.get_input_embeddings(input_ids)?;

        if let Some(pixel_values) = pixel_values {
            // Encode the images and project them into the text hidden size.
            let image_features = self.vision_model.forward(&pixel_values)?;

            let vision_flat = image_features.reshape(((), image_features.dim(D::Minus1)?))?;
            let projected_vision_flat = self.multi_modal_projector.forward(&vision_flat)?;

            // Mark every embedding element that belongs to an image placeholder token.
            let special_image_mask = input_ids
                .eq(self.image_token_index as f64)?
                .unsqueeze(D::Minus1)?
                .broadcast_as(input_embeds.shape().clone())?
                .to_dtype(DType::U32)?;

            let mask_flat = special_image_mask.flatten_all()?;
            let mut x_flat = input_embeds.flatten_all()?;
            let src_flat = projected_vision_flat.flatten_all()?;

            // Overwrite the masked positions with the projected vision features:
            // scatter-adding `src - current` at those indices leaves exactly `src`.
            let indices = mask_flat.nonzero()?.squeeze(1)?;
            let current_vals = x_flat.gather(&indices, 0)?;
            let diff = (src_flat - current_vals)?;
            x_flat = x_flat.scatter_add(&indices, &diff, 0)?;

            input_embeds = x_flat.reshape(input_embeds.shape())?;
        }

        self.language_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}

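// ISQ integration: collect the quantizable layers from both the text and
// vision towers; `residual_tensors` re-emits the remaining weights (such as
// the projector's `linear_1`) for serialization.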
impl IsqModel for Llama4Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let (mut layers, device_map) = self.language_model.get_layers();
        layers.extend(
            self.vision_model
                .get_isq_layers()
                .into_iter()
                .map(|x| (x, None)),
        );
        (layers, device_map)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("multi_modal_projector")
            .pp("linear_1")
            .add(&self.multi_modal_projector.linear_1);
        uvb.pp("language_model")
            .extend(self.language_model.residual_tensors());
        uvb.pp("vision_model")
            .extend(self.vision_model.residual_tensors());

        uvb.to_safetensors()
    }
}

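/// Llama 4 takes no extra per-request vision arguments; this unit struct only
/// satisfies the `model_specific_args` contract of `VisionModel::forward`.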
pub struct Llama4ModelSpecificArgs;

impl NormalModel for Llama4Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        self.forward(
            input_ids,
            None,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        self.language_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.language_model.cache_mut()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.language_model.config()
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn device(&self) -> &Device {
        self.language_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_seq_len()
    }
}

impl VisionModel for Llama4Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn std::any::Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        let Llama4ModelSpecificArgs = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Llama4ModelSpecificArgs`");
        self.forward(
            input_ids,
            pixel_values,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        self.language_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.language_model.cache_mut()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.language_model.config()
    }
    fn has_conv2d(&self) -> bool {
        false
    }
    fn device(&self) -> &Device {
        self.language_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_seq_len()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
        Box::new(Llama4ModelSpecificArgs)
    }
}

impl AnyMoeBaseModelMixin for Llama4Model {}
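
// A minimal sketch (not part of the upstream file) of the gather/scatter_add
// trick used in `Llama4Model::forward` to overwrite masked positions in place:
// scatter-adding `source - current` at the target indices leaves exactly
// `source` at those positions. The toy tensors, module, and test names below
// are illustrative assumptions, not upstream test code.
#[cfg(test)]
mod scatter_merge_sketch {
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn scatter_add_overwrites_selected_positions() -> Result<()> {
        let dev = Device::Cpu;
        // Flat "embedding" buffer; positions 1 and 3 stand in for image-token slots.
        let x_flat = Tensor::new(&[0f32, 1., 2., 3., 4., 5.], &dev)?;
        let indices = Tensor::new(&[1u32, 3], &dev)?;
        // Values destined for those slots (projected vision features in the real code).
        let src_flat = Tensor::new(&[70f32, 90.], &dev)?;

        // Read the current values at the masked positions, then add (src - current)
        // back in place, which replaces those positions with `src_flat`.
        let current = x_flat.gather(&indices, 0)?;
        let diff = (src_flat - current)?;
        let merged = x_flat.scatter_add(&indices, &diff, 0)?;

        assert_eq!(merged.to_vec1::<f32>()?, vec![0f32, 70., 2., 90., 4., 5.]);
        Ok(())
    }
}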